mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 16:25:45 +00:00
Compare commits
1 Commits
bot_nit
...
assistant_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d50b160c7 |
@@ -24,8 +24,6 @@ env:
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -73,7 +73,6 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,7 +8,6 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -1,40 +0,0 @@
|
||||
"""non-nullbale slack bot id in channel config
|
||||
|
||||
Revision ID: f7a894b06d02
|
||||
Revises: 9f696734098f
|
||||
Create Date: 2024-12-06 12:55:42.845723
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7a894b06d02"
|
||||
down_revision = "9f696734098f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete all rows with null slack_bot_id
|
||||
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
|
||||
|
||||
# Make slack_bot_id non-nullable
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Make slack_bot_id nullable again
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -18,11 +18,6 @@ class ExternalAccess:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
@@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_VERIFICATION
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -86,7 +87,6 @@ from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -99,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -131,12 +136,11 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
if AUTH_TYPE == AuthType.BASIC:
|
||||
return REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
# For other auth types, if the user is authenticated it's assumed that
|
||||
# the user is already verified via the external IDP
|
||||
return False
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return not DISABLE_VERIFICATION and (
|
||||
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -333,16 +332,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -39,6 +38,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -116,13 +116,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
|
||||
# set thread_local=False since we don't control what thread the periodic task might
|
||||
# reacquire the lock with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
@@ -231,7 +227,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
lock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
|
||||
@@ -2,55 +2,54 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"task": "check_for_external_group_sync",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -29,7 +28,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
|
||||
@@ -18,11 +18,9 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -84,7 +82,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
name="check_for_doc_permissions_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -166,7 +164,7 @@ def try_creating_permissions_sync_task(
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -193,7 +191,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
name="connector_permission_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -219,7 +217,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
@@ -263,12 +261,7 @@ def connector_permission_sync_generator_task(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
@@ -293,7 +286,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
name="update_external_document_permissions_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
@@ -304,8 +297,6 @@ def update_external_document_permissions_task(
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
@@ -314,28 +305,18 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
# Then we build the update requests to update vespa
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -32,14 +31,10 @@ from danswer.redis.redis_connector_ext_group_sync import (
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -90,7 +85,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
name="check_for_external_group_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -111,22 +106,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.id != cc_pair_to_remove.id
|
||||
]
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
@@ -182,7 +161,7 @@ def try_creating_external_group_sync_task(
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -212,7 +191,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
name="connector_external_group_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
|
||||
@@ -23,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -157,7 +156,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
@@ -487,7 +486,7 @@ def try_creating_indexing_task(
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -525,10 +524,7 @@ def try_creating_indexing_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
|
||||
)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
@@ -584,64 +580,39 @@ def connector_indexing_proxy_task(
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
"Indexing proxy - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
finally:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
continue
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -789,12 +760,9 @@ def connector_indexing_task(
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -8,7 +8,6 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -21,7 +20,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -77,7 +75,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -186,7 +184,7 @@ def try_creating_prune_generator_task(
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -211,7 +209,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -240,12 +238,9 @@ def connector_pruning_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -9,7 +9,6 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
@@ -81,7 +80,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -655,28 +654,24 @@ def monitor_ccpair_indexing_taskset(
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
result_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
result_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"result_state={result_state} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
@@ -708,7 +703,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -819,7 +814,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
|
||||
@@ -2,79 +2,20 @@ import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -295,71 +196,3 @@ def extract_headers(
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
@@ -1,30 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.context.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import Prompt
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
"""This contains the minimal set information for the LLM portion including citations"""
|
||||
@@ -38,7 +25,6 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
@@ -131,6 +117,20 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
@@ -146,20 +146,14 @@ class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class ChatDanswerBotResponse(BaseModel):
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
@@ -171,41 +165,9 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
@@ -221,109 +183,3 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
class DocumentPruningConfig(BaseModel):
|
||||
max_chunks: int | None = None
|
||||
max_window_percentage: float | None = None
|
||||
max_tokens: int | None = None
|
||||
# different pruning behavior is expected when the
|
||||
# user manually selects documents they want to chat with
|
||||
# e.g. we don't want to truncate each document to be no more
|
||||
# than one chunk long
|
||||
is_manually_selected_docs: bool = False
|
||||
# If user specifies to include additional context Chunks for each match, then different pruning
|
||||
# is used. As many Sections as possible are included, and the last Section is truncated
|
||||
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
|
||||
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
|
||||
use_sections: bool = True
|
||||
# If using tools, then we need to consider the tool length
|
||||
tool_num_tokens: int = 0
|
||||
# If using a tool message to represent the docs, then we have to JSON serialize
|
||||
# the document content, which adds to the token count.
|
||||
using_tool_message: bool = False
|
||||
|
||||
|
||||
class ContextualPruningConfig(DocumentPruningConfig):
|
||||
num_chunk_multiple: int
|
||||
|
||||
@classmethod
|
||||
def from_doc_pruning_config(
|
||||
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
|
||||
) -> "ContextualPruningConfig":
|
||||
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
include_citations: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model: "Prompt", prompt_override: PromptOverride | None = None
|
||||
) -> "PromptConfig":
|
||||
override_system_prompt = (
|
||||
prompt_override.system_prompt if prompt_override else None
|
||||
)
|
||||
override_task_prompt = prompt_override.task_prompt if prompt_override else None
|
||||
|
||||
return cls(
|
||||
system_prompt=override_system_prompt or model.system_prompt,
|
||||
task_prompt=override_task_prompt or model.task_prompt,
|
||||
datetime_aware=model.datetime_aware,
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
@@ -6,24 +6,16 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.answer import Answer
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import create_temporary_persona
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import CitationConfig
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DocumentPruningConfig
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
@@ -62,11 +54,16 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.utils import load_all_chat_files
|
||||
from danswer.file_store.utils import save_files
|
||||
from danswer.file_store.utils import save_files_from_urls
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
@@ -105,7 +102,6 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -117,10 +113,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -263,7 +256,6 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| DanswerContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -294,8 +286,6 @@ def stream_chat_message_objects(
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -303,7 +293,6 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
@@ -333,31 +322,17 @@ def stream_chat_message_objects(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
@@ -580,34 +555,19 @@ def stream_chat_message_objects(
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=prompt_override,
|
||||
)
|
||||
elif final_msg.prompt:
|
||||
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
@@ -627,13 +587,11 @@ def stream_chat_message_objects(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
rerank_settings=new_msg_req.rerank_settings,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
@@ -680,8 +638,7 @@ def stream_chat_message_objects(
|
||||
|
||||
reference_db_search_docs = None
|
||||
qa_docs_response = None
|
||||
# any files to associate with the AI message e.g. dall-e generated images
|
||||
ai_message_files = []
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
|
||||
@@ -736,14 +693,8 @@ def stream_chat_message_objects(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data
|
||||
for img in img_generation_response
|
||||
if img.image_data
|
||||
],
|
||||
tenant_id=tenant_id,
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
@@ -769,19 +720,15 @@ def stream_chat_message_objects(
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=(
|
||||
ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV
|
||||
),
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
@@ -790,8 +737,6 @@ def stream_chat_message_objects(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
||||
yield cast(DanswerContexts, packet.response)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
@@ -831,8 +776,7 @@ def stream_chat_message_objects(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
if not answer.is_cancelled():
|
||||
yield AllCitations(citations=answer.citations)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
@@ -901,30 +845,3 @@ def stream_chat_message(
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_slack(
|
||||
packets: ChatPacketStream,
|
||||
) -> ChatDanswerBotResponse:
|
||||
response = ChatDanswerBotResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
|
||||
raise ValueError(f"New message type {msg.message_type} not handled")
|
||||
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_danswer_msg_to_langchain(msg)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Necessary for cloud integration tests
|
||||
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
@@ -81,14 +84,7 @@ OAUTH_CLIENT_SECRET = (
|
||||
or ""
|
||||
)
|
||||
|
||||
# for future OAuth connector support
|
||||
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
@@ -122,8 +118,6 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
|
||||
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
|
||||
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
|
||||
# the number of times to try and connect to vespa on startup before giving up
|
||||
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)
|
||||
|
||||
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
|
||||
|
||||
@@ -314,22 +308,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
# The current state of affairs:
|
||||
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
|
||||
# no API retrieves the user's timezone
|
||||
# All data is returned in UTC, so we can't derive the user's timezone from that
|
||||
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -515,6 +493,10 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
@@ -528,6 +510,3 @@ API_KEY_HASH_ROUNDS = (
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
@@ -31,8 +31,6 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -261,32 +259,6 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
@@ -4,8 +4,11 @@ import os
|
||||
# Danswer Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
)
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
@@ -44,6 +47,17 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
# if set, will default DanswerBot to use quotes and reference documents
|
||||
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
|
||||
@@ -2,8 +2,6 @@ import json
|
||||
import os
|
||||
|
||||
|
||||
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
|
||||
|
||||
# if specified, will pass through request headers to the call to API calls made by custom tools
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
|
||||
|
||||
@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
|
||||
- Load Connector:
|
||||
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll Connector:
|
||||
- Poll connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Slim Connector:
|
||||
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
|
||||
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
|
||||
- This is used by our pruning job which removes old documents from the index.
|
||||
- The optional start and end datetimes can be ignored.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
- Currently not used by the background job, this exists for future design purposes.
|
||||
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
|
||||
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
|
||||
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
|
||||
|
||||
For implementing a Slim Connector, refer to the comments in this PR:
|
||||
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
|
||||
|
||||
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
|
||||
|
||||
|
||||
#### Implementing the new Connector
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -15,7 +13,6 @@ from danswer.connectors.confluence.utils import attachment_to_content
|
||||
from danswer.connectors.confluence.utils import build_confluence_document_id
|
||||
from danswer.connectors.confluence.utils import datetime_from_string
|
||||
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from danswer.connectors.confluence.utils import validate_attachment_filetype
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@@ -72,7 +69,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@@ -108,8 +104,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -210,14 +204,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
@@ -250,10 +242,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
@@ -277,11 +269,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
):
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
perm_sync_data = {
|
||||
"restrictions": page.get("restrictions", {}),
|
||||
"space_key": page.get("space", {}).get("key"),
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
@@ -291,7 +281,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
@@ -301,21 +291,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
if not validate_attachment_filetype(attachment):
|
||||
continue
|
||||
attachment_restrictions = attachment.get("restrictions")
|
||||
if not attachment_restrictions:
|
||||
attachment_restrictions = page_restrictions
|
||||
|
||||
attachment_space_key = attachment.get("space", {}).get("key")
|
||||
if not attachment_space_key:
|
||||
attachment_space_key = page_space_key
|
||||
|
||||
attachment_perm_sync_data = {
|
||||
"restrictions": attachment_restrictions or {},
|
||||
"space_key": attachment_space_key,
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
@@ -323,7 +298,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=attachment_perm_sync_data,
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
|
||||
@@ -134,32 +134,6 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -332,13 +306,6 @@ def _validate_connector_configuration(
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
|
||||
@@ -32,11 +32,7 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
@@ -177,23 +173,19 @@ def extract_text_from_confluence_html(
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
return attachment["metadata"]["mediaType"] not in [
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
]:
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
@@ -249,7 +241,7 @@ def build_confluence_document_id(
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
def extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachments in use
|
||||
|
||||
|
||||
@@ -12,15 +12,12 @@ from dateutil import parser
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -31,8 +28,6 @@ logger = setup_logger()
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
|
||||
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
slab_bot_token: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self._slab_bot_token: str | None = None
|
||||
self.slab_bot_token = slab_bot_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._slab_bot_token = credentials["slab_bot_token"]
|
||||
self.slab_bot_token = credentials["slab_bot_token"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def slab_bot_token(self) -> str:
|
||||
if self._slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
return self._slab_bot_token
|
||||
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
yield from self._iterate_posts(
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=post_id,
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -5,11 +5,7 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.chat.prune_and_merge import _merge_sections
|
||||
from danswer.chat.prune_and_merge import ChunkRange
|
||||
from danswer.chat.prune_and_merge import merge_chunk_intervals
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
@@ -31,6 +27,10 @@ from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaChunkRequest
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prune_and_merge import _merge_sections
|
||||
from danswer.llm.answering.prune_and_merge import ChunkRange
|
||||
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -16,7 +16,7 @@ from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
|
||||
from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -40,7 +40,10 @@ from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
_MAX_BLURB_LEN = 45
|
||||
|
||||
@@ -324,7 +327,7 @@ def _build_sources_blocks(
|
||||
|
||||
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -347,7 +350,7 @@ def _priority_ordered_documents_blocks(
|
||||
|
||||
|
||||
def _build_citations_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -366,8 +369,51 @@ def _build_citations_blocks(
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
doc_to_quotes: dict[str, list[str]] = {}
|
||||
doc_to_link: dict[str, str] = {}
|
||||
doc_to_sem_id: dict[str, str] = {}
|
||||
for q in quotes:
|
||||
quote = q.quote
|
||||
doc_id = q.document_id
|
||||
doc_link = q.link
|
||||
doc_name = q.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and quote:
|
||||
if doc_id not in doc_to_quotes:
|
||||
doc_to_quotes[doc_id] = [quote]
|
||||
doc_to_link[doc_id] = doc_link
|
||||
doc_to_sem_id[doc_id] = (
|
||||
doc_name
|
||||
if q.source_type != DocumentSource.SLACK.value
|
||||
else "#" + doc_name
|
||||
)
|
||||
else:
|
||||
doc_to_quotes[doc_id].append(quote)
|
||||
|
||||
for doc_id, quote_strs in doc_to_quotes.items():
|
||||
quotes_str_clean = [
|
||||
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
|
||||
]
|
||||
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
|
||||
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
|
||||
link = doc_to_link[doc_id]
|
||||
sem_id = doc_to_sem_id[doc_id]
|
||||
quote_lines.append(
|
||||
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
|
||||
)
|
||||
|
||||
if not doc_to_quotes:
|
||||
return []
|
||||
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
@@ -376,10 +422,13 @@ def _build_qa_response_blocks(
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
quotes = answer.quotes.quotes if answer.quotes else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
@@ -422,6 +471,16 @@ def _build_qa_response_blocks(
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = _build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `_build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
|
||||
)
|
||||
]
|
||||
|
||||
response_blocks: list[Block] = []
|
||||
|
||||
@@ -430,6 +489,9 @@ def _build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
@@ -505,9 +567,10 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
answer: OneShotQAResponse,
|
||||
persona: Persona | None,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
@@ -524,6 +587,7 @@ def build_slack_response_blocks(
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
@@ -553,7 +617,8 @@ def build_slack_response_blocks(
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations and answer.citations:
|
||||
if use_citations:
|
||||
# if citations are enabled, only show cited documents
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
@@ -572,5 +637,4 @@ def build_slack_response_blocks(
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
|
||||
return all_blocks
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -8,36 +9,46 @@ from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.chat_utils import prepare_chat_message_request
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.process_message import gather_stream_for_slack
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
@@ -72,14 +83,16 @@ def handle_regular_answer(
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
@@ -89,18 +102,9 @@ def handle_regular_answer(
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
prompt = None
|
||||
# If no persona is specified, use the default search based persona
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
else:
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
@@ -108,26 +112,6 @@ def handle_regular_answer(
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# TODO: Add in support for Slack to truncate messages based on max LLM context
|
||||
# llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=llm.config.model_name,
|
||||
# provider_type=llm.config.model_provider,
|
||||
# )
|
||||
|
||||
# # In cases of threads, split the available tokens between docs and thread context
|
||||
# input_tokens = get_max_input_tokens(
|
||||
# model_name=llm.config.model_name,
|
||||
# model_provider=llm.config.model_provider,
|
||||
# )
|
||||
# max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
# combined_message = combine_message_thread(
|
||||
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
|
||||
# )
|
||||
|
||||
combined_message = slackify_message_thread(messages)
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_channel_config
|
||||
@@ -138,6 +122,13 @@ def handle_regular_answer(
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_channel_config is None
|
||||
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
@@ -150,23 +141,75 @@ def handle_regular_answer(
|
||||
backoff=2,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, danswer_user: User | None
|
||||
) -> ChatDanswerBotResponse:
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=danswer_user,
|
||||
if len(new_message_request.messages) > 1:
|
||||
if new_message_request.persona_config:
|
||||
raise RuntimeError("Slack bot does not support persona config")
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No persona id provided, this should never happen."
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
answer = gather_stream_for_slack(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
@@ -196,24 +239,26 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
user=user,
|
||||
persona_id=persona.id,
|
||||
# This is not used in the Slack flow, only in the answer API
|
||||
persona_override_config=None,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=message_ts_to_respond_to,
|
||||
retrieval_details=retrieval_details,
|
||||
rerank_settings=None, # Rerank customization supported in Slack flow
|
||||
db_session=db_session,
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
answer = _get_slack_answer(
|
||||
new_message_request=answer_request, danswer_user=user
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -314,7 +359,7 @@ def handle_regular_answer(
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{combined_message}' - no documents found"
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -335,18 +380,18 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_if_citations = (
|
||||
only_respond_with_citations_or_quotes = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
if (
|
||||
only_respond_if_citations
|
||||
and not answer.citations
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -364,8 +409,9 @@ def handle_regular_answer(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
persona=persona,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True, # No longer supporting quotes
|
||||
use_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,33 +1,8 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
# Note: this does not handle extremely long threads, every message will be included
|
||||
# with weaker LLMs, this could cause issues with exceeeding the token limit
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"AI said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
|
||||
@@ -19,8 +19,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.app_configs import DEV_MODE
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -76,6 +74,7 @@ from danswer.db.slack_bot import fetch_slack_bots
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -251,7 +250,7 @@ class SlackbotHandler:
|
||||
nx=True,
|
||||
ex=TENANT_LOCK_EXPIRATION,
|
||||
)
|
||||
if not acquired and not DEV_MODE:
|
||||
if not acquired:
|
||||
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
|
||||
|
||||
class SlackMessageInfo(BaseModel):
|
||||
|
||||
@@ -30,13 +30,13 @@ from danswer.configs.danswerbot_configs import (
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import FeedbackVisibility
|
||||
from danswer.danswerbot.slack.models import ThreadMessage
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
|
||||
@@ -145,10 +145,16 @@ def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
only_one_shot: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -220,11 +226,12 @@ def delete_messages_and_files_from_chat_session(
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str | None,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
@@ -234,6 +241,7 @@ def create_chat_session(
|
||||
description=description,
|
||||
llm_override=llm_override,
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -279,6 +287,8 @@ def duplicate_chat_session_for_user_from_slack(
|
||||
description="",
|
||||
llm_override=chat_session.llm_override,
|
||||
prompt_override=chat_session.prompt_override,
|
||||
# Chat sessions from Slack should put people in the chat UI, not the search
|
||||
one_shot=False,
|
||||
# Chat is in UI now so this is false
|
||||
danswerbot_flow=False,
|
||||
# Maybe we want this in the future to track if it was created from Slack
|
||||
|
||||
@@ -248,6 +248,7 @@ def create_credential(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return credential
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -427,9 +426,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
@@ -963,8 +964,9 @@ class ChatSession(Base):
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# This chat created by DanswerBot
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Only ever set to True if system is set to not hard-delete chats
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -1486,13 +1488,16 @@ class ChannelConfig(TypedDict):
|
||||
show_continue_in_web_ui: NotRequired[bool] # defaults to False
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfig(Base):
|
||||
__tablename__ = "slack_channel_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
slack_bot_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("slack_bot.id"), nullable=False
|
||||
)
|
||||
slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
@@ -1500,6 +1505,9 @@ class SlackChannelConfig(Base):
|
||||
channel_config: Mapped[ChannelConfig] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False
|
||||
)
|
||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
|
||||
@@ -415,6 +415,9 @@ def upsert_prompt(
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: This operation cannot update persona configuration options that
|
||||
# are core to the persona, such as its display priority and
|
||||
# whether or not the assistant is a built-in / default assistant
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -446,12 +449,6 @@ def upsert_persona(
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
"""
|
||||
NOTE: This operation cannot update persona configuration options that
|
||||
are core to the persona, such as its display priority and
|
||||
whether or not the assistant is a built-in / default assistant
|
||||
"""
|
||||
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
else:
|
||||
@@ -489,8 +486,6 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
# Built-in personas can only be updated through YAML configuration.
|
||||
# This ensures that core system personas are not modified unintentionally.
|
||||
if persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
@@ -499,9 +494,6 @@ def upsert_persona(
|
||||
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
persona.name = name
|
||||
persona.description = description
|
||||
persona.num_chunks = num_chunks
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
@@ -82,6 +83,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -113,6 +115,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id=slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=response_type,
|
||||
standard_answer_categories=existing_standard_answer_categories,
|
||||
enable_auto_filters=enable_auto_filters,
|
||||
)
|
||||
@@ -127,6 +130,7 @@ def update_slack_channel_config(
|
||||
slack_channel_config_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -166,6 +170,7 @@ def update_slack_channel_config(
|
||||
# will encounter `violates foreign key constraint` errors
|
||||
slack_channel_config.persona_id = persona_id
|
||||
slack_channel_config.channel_config = channel_config
|
||||
slack_channel_config.response_type = response_type
|
||||
slack_channel_config.standard_answer_categories = list(
|
||||
existing_standard_answer_categories
|
||||
)
|
||||
|
||||
@@ -4,8 +4,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
# Not to be confused with the UUID generated for this chunk which is called documentid by default
|
||||
field document_id type string {
|
||||
indexing: summary | attribute
|
||||
attribute: fast-search
|
||||
rank: filter
|
||||
}
|
||||
field chunk_id type int {
|
||||
indexing: summary | attribute
|
||||
|
||||
@@ -6,7 +6,6 @@ import zipfile
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from email.parser import Parser as EmailParser
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
@@ -16,17 +15,13 @@ import chardet
|
||||
import docx # type: ignore
|
||||
import openpyxl # type: ignore
|
||||
import pptx # type: ignore
|
||||
from docx import Document
|
||||
from fastapi import UploadFile
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
from danswer.configs.constants import DANSWER_METADATA_FILENAME
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.file_store.file_store import FileStore
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -380,35 +375,3 @@ def extract_file_text(
|
||||
) from e
|
||||
logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
|
||||
return ""
|
||||
|
||||
|
||||
def convert_docx_to_txt(
|
||||
file: UploadFile, file_store: FileStore, file_path: str
|
||||
) -> None:
|
||||
file.file.seek(0)
|
||||
docx_content = file.file.read()
|
||||
doc = Document(BytesIO(docx_content))
|
||||
|
||||
# Extract text from the document
|
||||
full_text = []
|
||||
for para in doc.paragraphs:
|
||||
full_text.append(para.text)
|
||||
|
||||
# Join the extracted text
|
||||
text_content = "\n".join(full_text)
|
||||
|
||||
txt_file_path = docx_to_txt_filename(file_path)
|
||||
file_store.save_file(
|
||||
file_name=txt_file_path,
|
||||
content=BytesIO(text_content.encode("utf-8")),
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
file_type="text/plain",
|
||||
)
|
||||
|
||||
|
||||
def docx_to_txt_filename(file_path: str) -> str:
|
||||
"""
|
||||
Convert a .docx file path to its corresponding .txt file path.
|
||||
"""
|
||||
return file_path.rsplit(".", 1)[0] + ".txt"
|
||||
|
||||
@@ -59,12 +59,6 @@ class FileStore(ABC):
|
||||
Contents of the file and metadata dict
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def read_file_record(self, file_name: str) -> PGFileStore:
|
||||
"""
|
||||
Read the file record by the name
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_name: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -13,8 +13,8 @@ from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.b64 import get_image_type
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
@@ -75,58 +75,11 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
return unique_id
|
||||
|
||||
|
||||
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unique_id = str(uuid4())
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
file_name=unique_id,
|
||||
content=BytesIO(base64.b64decode(base64_string)),
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type=get_image_type(base64_string),
|
||||
)
|
||||
return unique_id
|
||||
def save_files_from_urls(urls: list[str]) -> list[str]:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
|
||||
def save_file(
|
||||
tenant_id: str,
|
||||
url: str | None = None,
|
||||
base64_data: str | None = None,
|
||||
) -> str:
|
||||
"""Save a file from either a URL or base64 encoded string.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to save the file under
|
||||
url: URL to download file from
|
||||
base64_data: Base64 encoded file data
|
||||
|
||||
Returns:
|
||||
The unique ID of the saved file
|
||||
|
||||
Raises:
|
||||
ValueError: If neither url nor base64_data is provided, or if both are provided
|
||||
"""
|
||||
if url is not None and base64_data is not None:
|
||||
raise ValueError("Cannot specify both url and base64_data")
|
||||
|
||||
if url is not None:
|
||||
return save_file_from_url(url, tenant_id)
|
||||
elif base64_data is not None:
|
||||
return save_file_from_base64(base64_data, tenant_id)
|
||||
else:
|
||||
raise ValueError("Must specify either url or base64_data")
|
||||
|
||||
|
||||
def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
|
||||
# NOTE: be explicit about typing so that if we change things, we get notified
|
||||
funcs: list[
|
||||
tuple[
|
||||
Callable[[str, str | None, str | None], str],
|
||||
tuple[str, str | None, str | None],
|
||||
]
|
||||
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
|
||||
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
|
||||
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
|
||||
(save_file_from_url, (url, tenant_id)) for url in urls
|
||||
]
|
||||
|
||||
# Must pass in tenant_id here, since this is called by multithreading
|
||||
return run_functions_tuples_in_parallel(funcs)
|
||||
|
||||
@@ -6,27 +6,33 @@ from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.chat.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.chat.prompt_builder.build import default_build_system_message
|
||||
from danswer.chat.prompt_builder.build import default_build_user_message
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import (
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.chat.stream_processing.answer_response_handler import (
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.chat.stream_processing.utils import map_document_id_order
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
@@ -208,23 +214,18 @@ class Answer:
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
@@ -1,22 +1,60 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: ToolResponseHandler,
|
||||
answer_handler: AnswerResponseHandler,
|
||||
tool_handler: "ToolResponseHandler",
|
||||
answer_handler: "AnswerResponseHandler",
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
163
backend/danswer/llm/answering/models.py
Normal file
163
backend/danswer/llm/answering/models.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Prompt
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
class PreviousMessage(BaseModel):
|
||||
"""Simplified version of `ChatMessage`"""
|
||||
|
||||
message: str
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
|
||||
) -> "PreviousMessage":
|
||||
message_file_ids = (
|
||||
[file["id"] for file in chat_message.files] if chat_message.files else []
|
||||
)
|
||||
return cls(
|
||||
message=chat_message.message,
|
||||
token_count=chat_message.token_count,
|
||||
message_type=chat_message.message_type,
|
||||
files=[
|
||||
file
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
content = build_content_with_imgs(self.message, self.files)
|
||||
if self.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
elif self.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
else:
|
||||
return SystemMessage(content=content)
|
||||
|
||||
|
||||
class DocumentPruningConfig(BaseModel):
|
||||
max_chunks: int | None = None
|
||||
max_window_percentage: float | None = None
|
||||
max_tokens: int | None = None
|
||||
# different pruning behavior is expected when the
|
||||
# user manually selects documents they want to chat with
|
||||
# e.g. we don't want to truncate each document to be no more
|
||||
# than one chunk long
|
||||
is_manually_selected_docs: bool = False
|
||||
# If user specifies to include additional context Chunks for each match, then different pruning
|
||||
# is used. As many Sections as possible are included, and the last Section is truncated
|
||||
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
|
||||
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
|
||||
use_sections: bool = True
|
||||
# If using tools, then we need to consider the tool length
|
||||
tool_num_tokens: int = 0
|
||||
# If using a tool message to represent the docs, then we have to JSON serialize
|
||||
# the document content, which adds to the token count.
|
||||
using_tool_message: bool = False
|
||||
|
||||
|
||||
class ContextualPruningConfig(DocumentPruningConfig):
|
||||
num_chunk_multiple: int
|
||||
|
||||
@classmethod
|
||||
def from_doc_pruning_config(
|
||||
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
|
||||
) -> "ContextualPruningConfig":
|
||||
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
include_citations: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model: "Prompt", prompt_override: PromptOverride | None = None
|
||||
) -> "PromptConfig":
|
||||
override_system_prompt = (
|
||||
prompt_override.system_prompt if prompt_override else None
|
||||
)
|
||||
override_task_prompt = prompt_override.task_prompt if prompt_override else None
|
||||
|
||||
return cls(
|
||||
system_prompt=override_system_prompt or model.system_prompt,
|
||||
task_prompt=override_task_prompt or model.task_prompt,
|
||||
datetime_aware=model.datetime_aware,
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
@@ -4,26 +4,20 @@ from typing import cast
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
@@ -145,15 +139,3 @@ class AnswerPromptBuilder:
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -2,12 +2,12 @@ from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.persona import get_default_prompt__read_only
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
@@ -1,10 +1,10 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
20
backend/danswer/llm/answering/prompts/utils.py
Normal file
20
backend/danswer/llm/answering/prompts/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
@@ -5,16 +5,16 @@ from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import ContextualPruningConfig
|
||||
from danswer.chat.models import (
|
||||
LlmDoc,
|
||||
)
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.llm.answering.models import ContextualPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
@@ -3,11 +3,16 @@ from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.llm_response_handler import ResponsePart
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -65,29 +70,28 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
# No longer in use, remove later
|
||||
# class QuotesResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# is_json_prompt: bool = True,
|
||||
# ):
|
||||
# self.quotes_processor = QuotesProcessor(
|
||||
# context_docs=context_docs,
|
||||
# is_json_prompt=is_json_prompt,
|
||||
# )
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | None,
|
||||
# previous_response_items: list[BaseMessage],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# if response_item is None:
|
||||
# yield from self.quotes_processor.process_token(None)
|
||||
# return
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item.content, str) else ""
|
||||
# )
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# yield from self.quotes_processor.process_token(content)
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
@@ -4,8 +4,8 @@ from collections.abc import Generator
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# THIS IS NO LONGER IN USE
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
@@ -6,10 +5,11 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
@@ -26,20 +26,6 @@ logger = setup_logger()
|
||||
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -14,11 +13,8 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
@@ -36,15 +32,11 @@ def get_main_llm_from_tuple(
|
||||
|
||||
|
||||
def get_llms_for_persona(
|
||||
persona: Persona | PersonaOverrideConfig | None,
|
||||
persona: Persona,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
if persona is None:
|
||||
logger.warning("No persona provided, using default LLMs")
|
||||
return get_default_llms()
|
||||
|
||||
model_provider_override = llm_override.model_provider if llm_override else None
|
||||
model_version_override = llm_override.model_version if llm_override else None
|
||||
temperature_override = llm_override.temperature if llm_override else None
|
||||
@@ -79,7 +71,6 @@ def get_llms_for_persona(
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
temperature=temperature_override,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
@@ -137,13 +128,11 @@ def get_llm(
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
temperature: float | None = None,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import ChatMessage
|
||||
|
||||
|
||||
class PreviousMessage(BaseModel):
|
||||
"""Simplified version of `ChatMessage`"""
|
||||
|
||||
message: str
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
|
||||
) -> "PreviousMessage":
|
||||
message_file_ids = (
|
||||
[file["id"] for file in chat_message.files] if chat_message.files else []
|
||||
)
|
||||
return cls(
|
||||
message=chat_message.message,
|
||||
token_count=chat_message.token_count,
|
||||
message_type=chat_message.message_type,
|
||||
files=[
|
||||
file
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
content = build_content_with_imgs(self.message, self.files)
|
||||
if self.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
elif self.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
else:
|
||||
return SystemMessage(content=content)
|
||||
@@ -5,6 +5,8 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
import litellm # type: ignore
|
||||
import pandas as pd
|
||||
@@ -34,15 +36,17 @@ from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
from danswer.utils.b64 import get_image_type
|
||||
from danswer.utils.b64 import get_image_type_from_bytes
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -100,6 +104,39 @@ def litellm_exception_to_error_msg(
|
||||
return error_msg
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: Union[ChatMessage, "PreviousMessage"],
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
|
||||
raise ValueError(f"New message type {msg.message_type} not handled")
|
||||
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_danswer_msg_to_langchain(msg)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
|
||||
# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
|
||||
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
|
||||
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
|
||||
@@ -153,7 +190,6 @@ def build_content_with_imgs(
|
||||
message: str,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
img_urls: list[str] | None = None,
|
||||
b64_imgs: list[str] | None = None,
|
||||
message_type: MessageType = MessageType.USER,
|
||||
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
|
||||
files = files or []
|
||||
@@ -166,7 +202,6 @@ def build_content_with_imgs(
|
||||
)
|
||||
|
||||
img_urls = img_urls or []
|
||||
b64_imgs = b64_imgs or []
|
||||
|
||||
message_main_content = _build_content(message, files)
|
||||
|
||||
@@ -185,22 +220,11 @@ def build_content_with_imgs(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": (
|
||||
f"data:{get_image_type_from_bytes(file.content)};"
|
||||
f"base64,{file.to_base64()}"
|
||||
),
|
||||
"url": f"data:image/jpeg;base64,{file.to_base64()}",
|
||||
},
|
||||
}
|
||||
for file in img_files
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
|
||||
},
|
||||
}
|
||||
for b64_img in b64_imgs
|
||||
for file in files
|
||||
if file.file_type == "image"
|
||||
]
|
||||
+ [
|
||||
{
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import create_danswer_oauth_router
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
@@ -91,7 +92,6 @@ from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.setup import setup_multitenant_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -105,6 +105,7 @@ from shared_configs.configs import CORS_ALLOWED_ORIGIN
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -205,7 +206,7 @@ def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
|
||||
if isinstance(exc, BasicAuthenticationError):
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
|
||||
logger.warning(f"Authentication failed: {str(exc)}")
|
||||
logger.error(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code >= 400:
|
||||
error_msg = f"{str(exc)}\n"
|
||||
|
||||
0
backend/danswer/one_shot_answer/__init__.py
Normal file
0
backend/danswer/one_shot_answer/__init__.py
Normal file
456
backend/danswer/one_shot_answer/answer_question.py
Normal file
456
backend/danswer/one_shot_answer/answer_question.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import reorganize_citations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import DocumentRelevance
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import RelevanceAnalysis
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import RerankMetricsContainer
|
||||
from danswer.context.search.models import RetrievalMetricsContainer
|
||||
from danswer.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.context.search.utils import dedupe_documents
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.one_shot_answer.qa_utils import slackify_message_thread
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
AnswerObjectIterator = Iterator[
|
||||
QueryRephrase
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| DanswerContexts
|
||||
| StreamingError
|
||||
| ChatMessageDetail
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| DocumentRelevance
|
||||
]
|
||||
|
||||
|
||||
def stream_answer_objects(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
# These need to be passed in because in Web UI one shot flow,
|
||||
# we can have much more document as there is no history.
|
||||
# For Slack flow, we need to save more tokens for the thread context
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
"""Streams in order:
|
||||
1. [always] Retrieved documents, stops flow if nothing is found
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed DanswerAnswerPiece and DanswerQuotes at the end
|
||||
or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
user_id = user.id if user is not None else None
|
||||
query_msg = query_req.messages[-1]
|
||||
history = query_req.messages[:-1]
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # One shot queries don't need naming as it's never displayed
|
||||
user_id=user_id,
|
||||
persona_id=query_req.persona_id,
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)}
|
||||
)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
|
||||
if query_req.persona_config is not None:
|
||||
temporary_persona = fetch_ee_implementation_or_noop(
|
||||
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
|
||||
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
try:
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona, long_term_logger=long_term_logger
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}"
|
||||
)
|
||||
if "No LLM provider" in str(e):
|
||||
raise ValueError(
|
||||
"Please configure a Generative AI model to use this feature."
|
||||
) from e
|
||||
raise ValueError(
|
||||
"Failed to initialize the AI model. Please check your configuration and try again."
|
||||
) from e
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Create a chat session which will just store the root message, the query, and the AI response
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
history_str = combine_message_thread(
|
||||
messages=history,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
rephrased_query = query_req.query_override or thread_based_query_rephrase(
|
||||
user_query=query_msg.message,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
# Given back ahead of the documents for latency reasons
|
||||
# In chat flow it's given back along with the documents
|
||||
yield QueryRephrase(rephrased_query=rephrased_query)
|
||||
|
||||
prompt = None
|
||||
if query_req.prompt_id is not None:
|
||||
# NOTE: let the user access any prompt as long as the Persona is shared
|
||||
# with them
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
user_message_str = query_msg.message
|
||||
# For this endpoint, we only save one user message to the chat session
|
||||
# However, for slackbot, we want to include the history of the entire thread
|
||||
if danswerbot_flow:
|
||||
# Right now, we only support bringing over citations and search docs
|
||||
# from the last message in the thread, not the entire thread
|
||||
# in the future, we may want to retrieve the entire thread
|
||||
user_message_str = slackify_message_thread(query_req.messages)
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=user_message_str,
|
||||
token_count=len(llm_tokenizer.encode(user_message_str)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type
|
||||
),
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(
|
||||
get_llms_for_persona(persona=persona, long_term_logger=long_term_logger)
|
||||
),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
skip_explicit_tool_calling=True,
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
# won't be any FileChatDisplay responses since that tool is never passed in
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
# (likely fine that it comes after the initial creation of the search docs)
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
top_docs = chunks_or_sections_to_search_docs(
|
||||
search_response_summary.top_sections
|
||||
)
|
||||
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
deduped_docs = top_docs
|
||||
if query_req.retrieval_options.dedupe_docs:
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
initial_response = QADocsResponse(
|
||||
rephrased_query=rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=search_response_summary.predicted_flow,
|
||||
predicted_search=search_response_summary.predicted_search,
|
||||
applied_source_filters=search_response_summary.final_filters.source_type,
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
document_based_response = {}
|
||||
|
||||
if packet.response is not None:
|
||||
for evaluation in packet.response:
|
||||
document_based_response[
|
||||
evaluation.document_id
|
||||
] = RelevanceAnalysis(
|
||||
relevant=evaluation.relevant, content=evaluation.content
|
||||
)
|
||||
|
||||
evaluation_response = DocumentRelevance(
|
||||
relevance_summaries=document_based_response
|
||||
)
|
||||
if reference_db_search_docs is not None:
|
||||
update_search_docs_table_with_relevance(
|
||||
db_session=db_session,
|
||||
reference_db_search_docs=reference_db_search_docs,
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=answer.llm_answer,
|
||||
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
yield msg_detail_response
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as session:
|
||||
objects = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
def get_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
answer_generation_timeout: int = QA_TIMEOUT,
|
||||
enable_reflexion: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> OneShotQAResponse:
|
||||
"""Collects the streamed one shot answer responses into a single object"""
|
||||
qa_response = OneShotQAResponse()
|
||||
|
||||
results = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
timeout=answer_generation_timeout,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
|
||||
answer = ""
|
||||
for packet in results:
|
||||
if isinstance(packet, QueryRephrase):
|
||||
qa_response.rephrase = packet.rephrased_query
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
qa_response.docs = packet
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, DanswerQuotes):
|
||||
qa_response.quotes = packet
|
||||
elif isinstance(packet, CitationInfo):
|
||||
if qa_response.citations:
|
||||
qa_response.citations.append(packet)
|
||||
else:
|
||||
qa_response.citations = [packet]
|
||||
elif isinstance(packet, DanswerContexts):
|
||||
qa_response.contexts = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
qa_response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
qa_response.chat_message_id = packet.message_id
|
||||
|
||||
if answer:
|
||||
qa_response.answer = answer
|
||||
|
||||
if enable_reflexion:
|
||||
# Because follow up messages are explicitly tagged, we don't need to verify the answer
|
||||
if len(query_req.messages) == 1:
|
||||
first_query = query_req.messages[0].message
|
||||
qa_response.answer_valid = get_answer_validity(first_query, answer)
|
||||
else:
|
||||
qa_response.answer_valid = True
|
||||
|
||||
if use_citations and qa_response.answer and qa_response.citations:
|
||||
# Reorganize citation nums to be in the same order as the answer
|
||||
qa_response.answer, qa_response.citations = reorganize_citations(
|
||||
qa_response.answer, qa_response.citations
|
||||
)
|
||||
|
||||
return qa_response
|
||||
114
backend/danswer/one_shot_answer/models.py
Normal file
114
backend/danswer/one_shot_answer/models.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
|
||||
|
||||
class QueryRephrase(BaseModel):
|
||||
rephrased_query: str
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
|
||||
chain_of_thought: bool = False
|
||||
return_contexts: bool = False
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
# will also disable Thread-based Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
raise ValueError(
|
||||
"If chain_of_thought is True, prompt_id must be None"
|
||||
"The chain of thought prompt is only for question "
|
||||
"answering and does not accept customizing."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
rephrase: str | None = None
|
||||
quotes: DanswerQuotes | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
chat_message_id: int | None = None
|
||||
contexts: DanswerContexts | None = None
|
||||
81
backend/danswer/one_shot_answer/qa_utils.py
Normal file
81
backend/danswer/one_shot_answer/qa_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
|
||||
"""Mock streaming by generating the passed in model output, character by character"""
|
||||
for token in model_out:
|
||||
yield token
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def slackify_message(message: ThreadMessage) -> str:
|
||||
if message.role != MessageType.USER:
|
||||
return message.message
|
||||
|
||||
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"DanswerBot said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
@@ -5,11 +5,11 @@ from typing import cast
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
|
||||
from danswer.prompts.chat_prompts import CITATION_REMINDER
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
@@ -106,7 +105,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.models import Document as DbDocument
|
||||
@@ -115,7 +114,7 @@ class RedisConnectorDelete:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -12,7 +12,6 @@ from danswer.access.models import DocExternalAccess
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncPayload(BaseModel):
|
||||
@@ -133,8 +132,6 @@ class RedisConnectorPermissionSync:
|
||||
lock: RedisLock | None,
|
||||
new_permissions: list[DocExternalAccess],
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
async_results = []
|
||||
@@ -152,13 +149,11 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
"update_external_document_permissions_task",
|
||||
kwargs=dict(
|
||||
tenant_id=self.tenant_id,
|
||||
serialized_doc_external_access=doc_perm.to_dict(),
|
||||
source_string=source_string,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
|
||||
|
||||
@@ -135,7 +134,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
@@ -77,7 +76,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -90,7 +89,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -3,14 +3,14 @@ from langchain.schema import HumanMessage
|
||||
from langchain.schema import SystemMessage
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.chat.prompt_builder.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE
|
||||
from danswer.prompts.chat_prompts import NO_SEARCH
|
||||
from danswer.prompts.chat_prompts import REQUIRE_SEARCH_HINT
|
||||
|
||||
@@ -4,10 +4,10 @@ from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE
|
||||
|
||||
@@ -20,7 +20,6 @@ from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.connectors.google_utils.google_auth import (
|
||||
@@ -86,7 +85,6 @@ from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.file_processing.extract_file_text import convert_docx_to_txt
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
@@ -394,12 +392,6 @@ def upload_files(
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
file_type=file.content_type or "text/plain",
|
||||
)
|
||||
|
||||
if file.content_type and file.content_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
convert_docx_to_txt(file, file_store, file_path)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return FileUploadResponse(file_paths=deduped_file_paths)
|
||||
@@ -875,7 +867,7 @@ def connector_run_once(
|
||||
|
||||
# run the beat task to pick up the triggers immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"check_for_indexing",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
@@ -1017,18 +1009,37 @@ def get_connector_by_id(
|
||||
|
||||
|
||||
class BasicCCPairInfo(BaseModel):
|
||||
docs_indexed: int
|
||||
has_successful_run: bool
|
||||
source: DocumentSource
|
||||
|
||||
|
||||
@router.get("/connector-status")
|
||||
@router.get("/indexing-status")
|
||||
def get_basic_connector_indexing_status(
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BasicCCPairInfo]:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_identifiers = [
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
]
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
)
|
||||
cc_pair_to_document_cnt = {
|
||||
(connector_id, credential_id): cnt
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
return [
|
||||
BasicCCPairInfo(
|
||||
docs_indexed=cc_pair_to_document_cnt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id)
|
||||
)
|
||||
or 0,
|
||||
has_successful_run=cc_pair.last_successful_index_time is not None,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.prompt_builder.utils import build_dummy_prompt
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.engine import get_session
|
||||
@@ -34,6 +33,7 @@ from danswer.db.persona import update_persona_shared_users
|
||||
from danswer.db.persona import update_persona_visibility
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.llm.answering.prompts.utils import build_dummy_prompt
|
||||
from danswer.server.features.persona.models import CreatePersonaRequest
|
||||
from danswer.server.features.persona.models import ImageGenerationToolStatus
|
||||
from danswer.server.features.persona.models import PersonaCategoryCreate
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -200,7 +199,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
|
||||
# run the beat task to pick up this deletion from the db immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"check_for_connector_deletion_task",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +15,7 @@ from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
|
||||
from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBot as SlackAppModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig as SlackChannelConfigModel
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
@@ -148,12 +148,6 @@ class SlackBotTokens(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
# TODO No longer in use, remove later
|
||||
class SlackBotResponseType(str, Enum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfigCreationRequest(BaseModel):
|
||||
slack_bot_id: int
|
||||
# currently, a persona is created for each Slack channel config
|
||||
@@ -203,6 +197,7 @@ class SlackChannelConfig(BaseModel):
|
||||
id: int
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
@@ -222,6 +217,7 @@ class SlackChannelConfig(BaseModel):
|
||||
else None
|
||||
),
|
||||
channel_config=slack_channel_config_model.channel_config,
|
||||
response_type=slack_channel_config_model.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
|
||||
@@ -118,6 +118,7 @@ def create_slack_channel_config(
|
||||
slack_bot_id=slack_channel_config_creation_request.slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_channel_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
@@ -181,6 +182,7 @@ def patch_slack_channel_config(
|
||||
slack_channel_config_id=slack_channel_config_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_channel_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from danswer.auth.noauth_user import fetch_no_auth_user
|
||||
from danswer.auth.noauth_user import set_no_auth_user_preferences
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserStatus
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
@@ -33,6 +34,7 @@ from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.api_key import is_api_key_email_address
|
||||
@@ -59,11 +61,9 @@ from danswer.server.manage.models import UserRoleUpdateRequest
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.server.utils import send_user_email_invite
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.configs.app_configs import SUPER_USERS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -194,11 +194,11 @@ def bulk_invite_users(
|
||||
)
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
new_invited_emails = []
|
||||
normalized_emails = []
|
||||
try:
|
||||
for email in emails:
|
||||
email_info = validate_email(email)
|
||||
new_invited_emails.append(email_info.normalized)
|
||||
normalized_emails.append(email_info.normalized) # type: ignore
|
||||
|
||||
except (EmailUndeliverableError, EmailNotValidError) as e:
|
||||
raise HTTPException(
|
||||
@@ -210,7 +210,7 @@ def bulk_invite_users(
|
||||
try:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning", "add_users_to_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
)(normalized_emails, tenant_id)
|
||||
|
||||
except IntegrityError as e:
|
||||
if isinstance(e.orig, UniqueViolation):
|
||||
@@ -224,7 +224,7 @@ def bulk_invite_users(
|
||||
|
||||
initial_invited_users = get_invited_users()
|
||||
|
||||
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
|
||||
all_emails = list(set(normalized_emails) | set(initial_invited_users))
|
||||
number_of_invited_users = write_invited_users(all_emails)
|
||||
|
||||
if not MULTI_TENANT:
|
||||
@@ -236,7 +236,7 @@ def bulk_invite_users(
|
||||
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in new_invited_emails:
|
||||
for email in all_emails:
|
||||
send_user_email_invite(email, current_user)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
@@ -250,7 +250,7 @@ def bulk_invite_users(
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
)(normalized_emails, tenant_id)
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
@@ -109,7 +109,6 @@ def process_run_in_background(
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=search_tool_retrieval_details, # Adjust as needed
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
regenerate=None,
|
||||
llm_override=None,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
@@ -24,9 +23,6 @@ from danswer.auth.users import current_user
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import extract_headers
|
||||
from danswer.chat.process_message import stream_chat_message
|
||||
from danswer.chat.prompt_builder.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
@@ -51,11 +47,13 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_processing.extract_file_text import docx_to_txt_filename
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
@@ -709,30 +707,14 @@ def upload_files_for_chat(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/file/{file_id:path}")
|
||||
@router.get("/file/{file_id}")
|
||||
def fetch_chat_file(
|
||||
file_id: str,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> Response:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
original_file_name = file_record.display_name
|
||||
if file_record.file_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
# Check if a converted text file exists for .docx files
|
||||
txt_file_name = docx_to_txt_filename(original_file_name)
|
||||
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
|
||||
txt_file_record = file_store.read_file_record(txt_file_id)
|
||||
if txt_file_record:
|
||||
file_record = txt_file_record
|
||||
file_id = txt_file_id
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
return StreamingResponse(file_io, media_type=media_type)
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import RetrievalDocs
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.context.search.models import Tag
|
||||
@@ -23,9 +20,6 @@ from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
source: DocumentSource
|
||||
@@ -93,8 +87,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None
|
||||
retrieval_options: RetrievalDetails | None
|
||||
# Useable via the APIs but not recommended for most flows
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
@@ -110,10 +102,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# allow user to specify an alternate assistnat
|
||||
alternate_assistant_id: int | None = None
|
||||
|
||||
# This takes the priority over the prompt_override
|
||||
# This won't be a type that's passed in directly from the API
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
@@ -157,7 +145,7 @@ class RenameChatSessionResponse(BaseModel):
|
||||
|
||||
class ChatSessionDetails(BaseModel):
|
||||
id: UUID
|
||||
name: str | None
|
||||
name: str
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
@@ -210,14 +198,14 @@ class ChatMessageDetail(BaseModel):
|
||||
|
||||
class SearchSessionDetailResponse(BaseModel):
|
||||
search_session_id: UUID
|
||||
description: str | None
|
||||
description: str
|
||||
documents: list[SearchDoc]
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
description: str
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
@@ -28,6 +32,8 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.tag import find_tags
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.one_shot_answer.answer_question import stream_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.server.query_and_chat.models import AdminSearchRequest
|
||||
from danswer.server.query_and_chat.models import AdminSearchResponse
|
||||
from danswer.server.query_and_chat.models import ChatSessionDetails
|
||||
@@ -35,6 +41,7 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
|
||||
from danswer.server.query_and_chat.models import SearchSessionDetailResponse
|
||||
from danswer.server.query_and_chat.models import SourceTag
|
||||
from danswer.server.query_and_chat.models import TagResponse
|
||||
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -133,7 +140,7 @@ def get_user_search_sessions(
|
||||
|
||||
try:
|
||||
search_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session
|
||||
user_id=user_id, deleted=False, db_session=db_session, only_one_shot=True
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
@@ -222,3 +229,29 @@ def get_search_session(
|
||||
],
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@basic_router.post("/stream-answer-with-quote")
|
||||
def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User = Depends(current_limited_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
query = query_request.messages[0].message
|
||||
|
||||
logger.notice(f"Received query for one shot answer with quotes: {query}")
|
||||
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
for packet in stream_search_answer(
|
||||
query_req=query_request,
|
||||
user=user,
|
||||
max_document_tokens=None,
|
||||
max_history_tokens=0,
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
except Exception as e:
|
||||
logger.exception("Error in search answer streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/json")
|
||||
|
||||
@@ -6,9 +6,6 @@ from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
@@ -17,11 +14,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import MANAGED_VESPA
|
||||
from danswer.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import KV_SEARCH_SETTINGS
|
||||
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
|
||||
@@ -222,13 +221,13 @@ def setup_vespa(
|
||||
document_index: DocumentIndex,
|
||||
index_setting: IndexingSetting,
|
||||
secondary_index_setting: IndexingSetting | None,
|
||||
num_attempts: int = VESPA_NUM_ATTEMPTS_ON_STARTUP,
|
||||
) -> bool:
|
||||
# Vespa startup is a bit slow, so give it a few seconds
|
||||
WAIT_SECONDS = 5
|
||||
for x in range(num_attempts):
|
||||
VESPA_ATTEMPTS = 5
|
||||
for x in range(VESPA_ATTEMPTS):
|
||||
try:
|
||||
logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...")
|
||||
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=index_setting.model_dim,
|
||||
secondary_index_embedding_dim=secondary_index_setting.model_dim
|
||||
@@ -245,7 +244,7 @@ def setup_vespa(
|
||||
time.sleep(WAIT_SECONDS)
|
||||
|
||||
logger.error(
|
||||
f"Vespa setup did not succeed. Attempt limit reached. ({num_attempts})"
|
||||
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CustomToolCallSummary,
|
||||
)
|
||||
|
||||
@@ -3,13 +3,13 @@ from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import CitationConfig
|
||||
from danswer.chat.models import DocumentPruningConfig
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
@@ -17,12 +13,15 @@ from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
@@ -103,14 +102,11 @@ class SearchToolConfig(BaseModel):
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
chunks_above: int = 0
|
||||
chunks_below: int = 0
|
||||
full_doc: bool = False
|
||||
latest_query_files: list[InMemoryChatFile] | None = None
|
||||
# Use with care, should only be used for DanswerBot in channels with multiple users
|
||||
bypass_acl: bool = False
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
|
||||
@@ -174,8 +170,6 @@ def construct_tools(
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
|
||||
@@ -15,14 +15,14 @@ from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.tools.base_tool import BaseTool
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
|
||||
@@ -4,16 +4,14 @@ from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from litellm import image_generation # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
@@ -58,18 +56,9 @@ Follow Up Input:
|
||||
""".strip()
|
||||
|
||||
|
||||
class ImageFormat(str, Enum):
|
||||
URL = "url"
|
||||
BASE64 = "b64_json"
|
||||
|
||||
|
||||
_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
revised_prompt: str
|
||||
url: str | None
|
||||
image_data: str | None
|
||||
url: str
|
||||
|
||||
|
||||
class ImageShape(str, Enum):
|
||||
@@ -91,7 +80,6 @@ class ImageGenerationTool(Tool):
|
||||
model: str = "dall-e-3",
|
||||
num_imgs: int = 2,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
@@ -101,7 +89,6 @@ class ImageGenerationTool(Tool):
|
||||
self.num_imgs = num_imgs
|
||||
|
||||
self.additional_headers = additional_headers
|
||||
self.output_format = output_format
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -181,7 +168,7 @@ class ImageGenerationTool(Tool):
|
||||
)
|
||||
|
||||
return build_content_with_imgs(
|
||||
message=json.dumps(
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"revised_prompt": image_generation.revised_prompt,
|
||||
@@ -190,10 +177,13 @@ class ImageGenerationTool(Tool):
|
||||
for image_generation in image_generations
|
||||
]
|
||||
),
|
||||
# NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
|
||||
# Tool messages to contain images
|
||||
# img_urls=[image_generation.url for image_generation in image_generations],
|
||||
)
|
||||
|
||||
def _generate_image(
|
||||
self, prompt: str, shape: ImageShape, format: ImageFormat
|
||||
self, prompt: str, shape: ImageShape
|
||||
) -> ImageGenerationResponse:
|
||||
if shape == ImageShape.LANDSCAPE:
|
||||
size = "1792x1024"
|
||||
@@ -207,32 +197,20 @@ class ImageGenerationTool(Tool):
|
||||
prompt=prompt,
|
||||
model=self.model,
|
||||
api_key=self.api_key,
|
||||
# need to pass in None rather than empty str
|
||||
api_base=self.api_base or None,
|
||||
api_version=self.api_version or None,
|
||||
size=size,
|
||||
n=1,
|
||||
response_format=format,
|
||||
extra_headers=build_llm_extra_headers(self.additional_headers),
|
||||
)
|
||||
|
||||
if format == ImageFormat.URL:
|
||||
url = response.data[0]["url"]
|
||||
image_data = None
|
||||
else:
|
||||
url = None
|
||||
image_data = response.data[0]["b64_json"]
|
||||
|
||||
return ImageGenerationResponse(
|
||||
revised_prompt=response.data[0]["revised_prompt"],
|
||||
url=url,
|
||||
image_data=image_data,
|
||||
url=response.data[0]["url"],
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error fetching or converting image: {e}")
|
||||
raise ValueError("Failed to fetch or convert the generated image")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error occurred during image generation: {e}")
|
||||
logger.debug(f"Error occured during image generation: {e}")
|
||||
|
||||
error_message = str(e)
|
||||
if "OpenAIException" in str(type(e)):
|
||||
@@ -257,8 +235,9 @@ class ImageGenerationTool(Tool):
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
prompt = cast(str, kwargs["prompt"])
|
||||
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
|
||||
format = self.output_format
|
||||
|
||||
# dalle3 only supports 1 image at a time, which is why we have to
|
||||
# parallelize this via threading
|
||||
results = cast(
|
||||
list[ImageGenerationResponse],
|
||||
run_functions_tuples_in_parallel(
|
||||
@@ -268,7 +247,6 @@ class ImageGenerationTool(Tool):
|
||||
(
|
||||
prompt,
|
||||
shape,
|
||||
format,
|
||||
),
|
||||
)
|
||||
for _ in range(self.num_imgs)
|
||||
@@ -310,17 +288,11 @@ class ImageGenerationTool(Tool):
|
||||
if img_generation_response is None:
|
||||
raise ValueError("No image generation response found")
|
||||
|
||||
img_urls = [img.url for img in img_generation_response if img.url is not None]
|
||||
b64_imgs = [
|
||||
img.image_data
|
||||
for img in img_generation_response
|
||||
if img.image_data is not None
|
||||
]
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
img_urls=img_urls,
|
||||
b64_imgs=b64_imgs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -11,14 +11,11 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or
|
||||
|
||||
|
||||
def build_image_generation_user_prompt(
|
||||
query: str,
|
||||
img_urls: list[str] | None = None,
|
||||
b64_imgs: list[str] | None = None,
|
||||
query: str, img_urls: list[str] | None = None
|
||||
) -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
|
||||
b64_imgs=b64_imgs,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -7,15 +7,15 @@ from typing import cast
|
||||
import httpx
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
@@ -77,7 +77,6 @@ def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
|
||||
updated_at=datetime.now(),
|
||||
link=result.link,
|
||||
source_links={0: result.link},
|
||||
match_highlights=[],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,19 +7,10 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.llm_response_handler import LLMCall
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import ContextualPruningConfig
|
||||
from danswer.chat.models import DanswerContext
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DocumentPruningConfig
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.chat.prune_and_merge import prune_and_merge_sections
|
||||
from danswer.chat.prune_and_merge import prune_sections
|
||||
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -28,14 +19,22 @@ from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import ContextualPruningConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
|
||||
from danswer.llm.answering.prune_and_merge import prune_sections
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
@@ -104,7 +103,6 @@ class SearchTool(Tool):
|
||||
chunks_below: int | None = None,
|
||||
full_doc: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
) -> None:
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
@@ -120,9 +118,6 @@ class SearchTool(Tool):
|
||||
self.bypass_acl = bypass_acl
|
||||
self.db_session = db_session
|
||||
|
||||
# Only used via API
|
||||
self.rerank_settings = rerank_settings
|
||||
|
||||
self.chunks_above = (
|
||||
chunks_above
|
||||
if chunks_above is not None
|
||||
@@ -297,7 +292,6 @@ class SearchTool(Tool):
|
||||
self.retrieval_options.offset if self.retrieval_options else None
|
||||
),
|
||||
limit=self.retrieval_options.limit if self.retrieval_options else None,
|
||||
rerank_settings=self.rerank_settings,
|
||||
chunks_above=self.chunks_above,
|
||||
chunks_below=self.chunks_below,
|
||||
full_doc=self.full_doc,
|
||||
|
||||
@@ -2,15 +2,15 @@ from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.chat.prompt_builder.citations_prompt import (
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
build_citations_system_message,
|
||||
)
|
||||
from danswer.chat.prompt_builder.citations_prompt import build_citations_user_message
|
||||
from danswer.chat.prompt_builder.quotes_prompt import build_quotes_user_message
|
||||
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
|
||||
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user