mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-20 01:05:46 +00:00
Compare commits
47 Commits
additional
...
updated_ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d97e96b3f0 | ||
|
|
911fbfa5a6 | ||
|
|
d02305671a | ||
|
|
bdfa29dcb5 | ||
|
|
897ed03c19 | ||
|
|
49f0c4f1f8 | ||
|
|
338c02171b | ||
|
|
ef1ade84b6 | ||
|
|
7c81566c54 | ||
|
|
c9df0aea47 | ||
|
|
92e0aeecba | ||
|
|
30c7e07783 | ||
|
|
e99704e9bd | ||
|
|
7f36387f7f | ||
|
|
407592445b | ||
|
|
2e533d8188 | ||
|
|
5b56869937 | ||
|
|
7baeab54e2 | ||
|
|
aefcfb75ef | ||
|
|
e5adcb457d | ||
|
|
db6463644a | ||
|
|
e26ba70cc6 | ||
|
|
66ff723c94 | ||
|
|
dda66f2178 | ||
|
|
0a27f72d20 | ||
|
|
fe397601ed | ||
|
|
3bc187c1d1 | ||
|
|
9a0b9eecf0 | ||
|
|
e08db414c0 | ||
|
|
b5734057b7 | ||
|
|
56beb3ec82 | ||
|
|
9f2c8118d7 | ||
|
|
6e4a3d5d57 | ||
|
|
5b3dcf718f | ||
|
|
07bd20b5b9 | ||
|
|
eb01b175ae | ||
|
|
6f55e5fe56 | ||
|
|
18e7609bfc | ||
|
|
dd69ec6cdb | ||
|
|
e961fa2820 | ||
|
|
d41bf9a3ff | ||
|
|
e3a6c76d51 | ||
|
|
719c2aa0df | ||
|
|
09f487e402 | ||
|
|
33a1548fc1 | ||
|
|
e87c93226a | ||
|
|
5e11a79593 |
@@ -24,8 +24,6 @@ env:
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -73,7 +73,6 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,7 +8,6 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -18,11 +18,6 @@ class ExternalAccess:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
@@ -87,7 +87,6 @@ from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -100,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -333,16 +332,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -39,6 +38,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -116,13 +116,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
|
||||
# set thread_local=False since we don't control what thread the periodic task might
|
||||
# reacquire the lock with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
@@ -231,7 +227,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
lock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
|
||||
@@ -2,55 +2,54 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"task": "check_for_external_group_sync",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -29,7 +28,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
|
||||
@@ -18,11 +18,9 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -84,7 +82,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
name="check_for_doc_permissions_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -166,7 +164,7 @@ def try_creating_permissions_sync_task(
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -193,7 +191,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
name="connector_permission_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -263,12 +261,7 @@ def connector_permission_sync_generator_task(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
@@ -293,7 +286,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
name="update_external_document_permissions_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
@@ -304,8 +297,6 @@ def update_external_document_permissions_task(
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
@@ -314,28 +305,18 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
# Then we build the update requests to update vespa
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -32,14 +31,10 @@ from danswer.redis.redis_connector_ext_group_sync import (
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -90,7 +85,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
name="check_for_external_group_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -111,22 +106,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.id != cc_pair_to_remove.id
|
||||
]
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
@@ -182,7 +161,7 @@ def try_creating_external_group_sync_task(
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -212,7 +191,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
name="connector_external_group_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
|
||||
@@ -23,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -157,7 +156,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
@@ -487,7 +486,7 @@ def try_creating_indexing_task(
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -525,10 +524,7 @@ def try_creating_indexing_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
|
||||
)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
@@ -584,64 +580,39 @@ def connector_indexing_proxy_task(
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
"Indexing proxy - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
finally:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
continue
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -789,12 +760,9 @@ def connector_indexing_task(
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -8,7 +8,6 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -21,7 +20,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -77,7 +75,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -186,7 +184,7 @@ def try_creating_prune_generator_task(
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -211,7 +209,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -240,12 +238,9 @@ def connector_pruning_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -9,7 +9,6 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
@@ -81,7 +80,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -655,28 +654,24 @@ def monitor_ccpair_indexing_taskset(
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
result_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
result_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"result_state={result_state} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
@@ -708,7 +703,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -819,7 +814,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
|
||||
@@ -2,79 +2,20 @@ import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -295,71 +196,3 @@ def extract_headers(
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
@@ -4,14 +4,12 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.context.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
@@ -27,7 +25,6 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
@@ -120,6 +117,20 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
@@ -135,20 +146,14 @@ class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class ChatDanswerBotResponse(BaseModel):
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
@@ -160,41 +165,9 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
|
||||
@@ -7,13 +7,10 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import create_temporary_persona
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
@@ -105,7 +102,6 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -117,10 +113,8 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -262,7 +256,6 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| DanswerContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -293,8 +286,6 @@ def stream_chat_message_objects(
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -331,31 +322,17 @@ def stream_chat_message_objects(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
@@ -578,34 +555,19 @@ def stream_chat_message_objects(
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=prompt_override,
|
||||
)
|
||||
elif final_msg.prompt:
|
||||
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
@@ -625,13 +587,11 @@ def stream_chat_message_objects(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
rerank_settings=new_msg_req.rerank_settings,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
@@ -777,8 +737,6 @@ def stream_chat_message_objects(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
||||
yield cast(DanswerContexts, packet.response)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
@@ -887,30 +845,3 @@ def stream_chat_message(
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_slack(
|
||||
packets: ChatPacketStream,
|
||||
) -> ChatDanswerBotResponse:
|
||||
response = ChatDanswerBotResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
|
||||
@@ -308,22 +308,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
# The current state of affairs:
|
||||
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
|
||||
# no API retrieves the user's timezone
|
||||
# All data is returned in UTC, so we can't derive the user's timezone from that
|
||||
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -509,6 +493,10 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
@@ -522,6 +510,3 @@ API_KEY_HASH_ROUNDS = (
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
@@ -31,8 +31,6 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -261,32 +259,6 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
@@ -4,8 +4,11 @@ import os
|
||||
# Danswer Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
)
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
@@ -44,6 +47,17 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
# if set, will default DanswerBot to use quotes and reference documents
|
||||
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
|
||||
@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
|
||||
- Load Connector:
|
||||
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll Connector:
|
||||
- Poll connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Slim Connector:
|
||||
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
|
||||
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
|
||||
- This is used by our pruning job which removes old documents from the index.
|
||||
- The optional start and end datetimes can be ignored.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
- Currently not used by the background job, this exists for future design purposes.
|
||||
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
|
||||
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
|
||||
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
|
||||
|
||||
For implementing a Slim Connector, refer to the comments in this PR:
|
||||
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
|
||||
|
||||
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
|
||||
|
||||
|
||||
#### Implementing the new Connector
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -71,7 +69,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@@ -107,8 +104,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -209,14 +204,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
@@ -249,10 +242,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
|
||||
@@ -134,32 +134,6 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -332,13 +306,6 @@ def _validate_connector_configuration(
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
|
||||
@@ -32,11 +32,7 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -12,15 +12,12 @@ from dateutil import parser
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -31,8 +28,6 @@ logger = setup_logger()
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
|
||||
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
slab_bot_token: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self._slab_bot_token: str | None = None
|
||||
self.slab_bot_token = slab_bot_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._slab_bot_token = credentials["slab_bot_token"]
|
||||
self.slab_bot_token = credentials["slab_bot_token"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def slab_bot_token(self) -> str:
|
||||
if self._slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
return self._slab_bot_token
|
||||
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
yield from self._iterate_posts(
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=post_id,
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -16,7 +16,7 @@ from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
|
||||
from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -40,7 +40,10 @@ from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
_MAX_BLURB_LEN = 45
|
||||
|
||||
@@ -324,7 +327,7 @@ def _build_sources_blocks(
|
||||
|
||||
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -347,7 +350,7 @@ def _priority_ordered_documents_blocks(
|
||||
|
||||
|
||||
def _build_citations_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -366,8 +369,51 @@ def _build_citations_blocks(
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
doc_to_quotes: dict[str, list[str]] = {}
|
||||
doc_to_link: dict[str, str] = {}
|
||||
doc_to_sem_id: dict[str, str] = {}
|
||||
for q in quotes:
|
||||
quote = q.quote
|
||||
doc_id = q.document_id
|
||||
doc_link = q.link
|
||||
doc_name = q.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and quote:
|
||||
if doc_id not in doc_to_quotes:
|
||||
doc_to_quotes[doc_id] = [quote]
|
||||
doc_to_link[doc_id] = doc_link
|
||||
doc_to_sem_id[doc_id] = (
|
||||
doc_name
|
||||
if q.source_type != DocumentSource.SLACK.value
|
||||
else "#" + doc_name
|
||||
)
|
||||
else:
|
||||
doc_to_quotes[doc_id].append(quote)
|
||||
|
||||
for doc_id, quote_strs in doc_to_quotes.items():
|
||||
quotes_str_clean = [
|
||||
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
|
||||
]
|
||||
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
|
||||
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
|
||||
link = doc_to_link[doc_id]
|
||||
sem_id = doc_to_sem_id[doc_id]
|
||||
quote_lines.append(
|
||||
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
|
||||
)
|
||||
|
||||
if not doc_to_quotes:
|
||||
return []
|
||||
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
@@ -376,10 +422,13 @@ def _build_qa_response_blocks(
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
quotes = answer.quotes.quotes if answer.quotes else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
@@ -422,6 +471,16 @@ def _build_qa_response_blocks(
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = _build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `_build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
|
||||
)
|
||||
]
|
||||
|
||||
response_blocks: list[Block] = []
|
||||
|
||||
@@ -430,6 +489,9 @@ def _build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
@@ -505,9 +567,10 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
answer: OneShotQAResponse,
|
||||
persona: Persona | None,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
@@ -524,6 +587,7 @@ def build_slack_response_blocks(
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
@@ -553,7 +617,8 @@ def build_slack_response_blocks(
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations and answer.citations:
|
||||
if use_citations:
|
||||
# if citations are enabled, only show cited documents
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
@@ -572,5 +637,4 @@ def build_slack_response_blocks(
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
|
||||
return all_blocks
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -8,36 +9,46 @@ from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.chat_utils import prepare_chat_message_request
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.process_message import gather_stream_for_slack
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
@@ -72,14 +83,16 @@ def handle_regular_answer(
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
@@ -89,18 +102,9 @@ def handle_regular_answer(
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
prompt = None
|
||||
# If no persona is specified, use the default search based persona
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
else:
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
@@ -108,26 +112,6 @@ def handle_regular_answer(
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# TODO: Add in support for Slack to truncate messages based on max LLM context
|
||||
# llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=llm.config.model_name,
|
||||
# provider_type=llm.config.model_provider,
|
||||
# )
|
||||
|
||||
# # In cases of threads, split the available tokens between docs and thread context
|
||||
# input_tokens = get_max_input_tokens(
|
||||
# model_name=llm.config.model_name,
|
||||
# model_provider=llm.config.model_provider,
|
||||
# )
|
||||
# max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
# combined_message = combine_message_thread(
|
||||
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
|
||||
# )
|
||||
|
||||
combined_message = slackify_message_thread(messages)
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_channel_config
|
||||
@@ -138,6 +122,13 @@ def handle_regular_answer(
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_channel_config is None
|
||||
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
@@ -150,23 +141,75 @@ def handle_regular_answer(
|
||||
backoff=2,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, danswer_user: User | None
|
||||
) -> ChatDanswerBotResponse:
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=danswer_user,
|
||||
if len(new_message_request.messages) > 1:
|
||||
if new_message_request.persona_config:
|
||||
raise RuntimeError("Slack bot does not support persona config")
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No persona id provided, this should never happen."
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
answer = gather_stream_for_slack(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
@@ -196,24 +239,26 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
user=user,
|
||||
persona_id=persona.id,
|
||||
# This is not used in the Slack flow, only in the answer API
|
||||
persona_override_config=None,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=message_ts_to_respond_to,
|
||||
retrieval_details=retrieval_details,
|
||||
rerank_settings=None, # Rerank customization supported in Slack flow
|
||||
db_session=db_session,
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
answer = _get_slack_answer(
|
||||
new_message_request=answer_request, danswer_user=user
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -314,7 +359,7 @@ def handle_regular_answer(
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{combined_message}' - no documents found"
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -335,18 +380,18 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_if_citations = (
|
||||
only_respond_with_citations_or_quotes = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
if (
|
||||
only_respond_if_citations
|
||||
and not answer.citations
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -364,8 +409,9 @@ def handle_regular_answer(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
persona=persona,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True, # No longer supporting quotes
|
||||
use_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,33 +1,8 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
# Note: this does not handle extremely long threads, every message will be included
|
||||
# with weaker LLMs, this could cause issues with exceeeding the token limit
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"AI said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
|
||||
@@ -19,8 +19,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.app_configs import DEV_MODE
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -76,6 +74,7 @@ from danswer.db.slack_bot import fetch_slack_bots
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -251,7 +250,7 @@ class SlackbotHandler:
|
||||
nx=True,
|
||||
ex=TENANT_LOCK_EXPIRATION,
|
||||
)
|
||||
if not acquired and not DEV_MODE:
|
||||
if not acquired:
|
||||
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
|
||||
|
||||
class SlackMessageInfo(BaseModel):
|
||||
|
||||
@@ -30,13 +30,13 @@ from danswer.configs.danswerbot_configs import (
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import FeedbackVisibility
|
||||
from danswer.danswerbot.slack.models import ThreadMessage
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
|
||||
@@ -145,10 +145,16 @@ def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
only_one_shot: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -220,11 +226,12 @@ def delete_messages_and_files_from_chat_session(
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str | None,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
@@ -234,6 +241,7 @@ def create_chat_session(
|
||||
description=description,
|
||||
llm_override=llm_override,
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -279,6 +287,8 @@ def duplicate_chat_session_for_user_from_slack(
|
||||
description="",
|
||||
llm_override=chat_session.llm_override,
|
||||
prompt_override=chat_session.prompt_override,
|
||||
# Chat sessions from Slack should put people in the chat UI, not the search
|
||||
one_shot=False,
|
||||
# Chat is in UI now so this is false
|
||||
danswerbot_flow=False,
|
||||
# Maybe we want this in the future to track if it was created from Slack
|
||||
|
||||
@@ -37,7 +37,6 @@ from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -427,9 +426,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
@@ -963,8 +964,9 @@ class ChatSession(Base):
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# This chat created by DanswerBot
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Only ever set to True if system is set to not hard-delete chats
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -1486,6 +1488,11 @@ class ChannelConfig(TypedDict):
|
||||
show_continue_in_web_ui: NotRequired[bool] # defaults to False
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfig(Base):
|
||||
__tablename__ = "slack_channel_config"
|
||||
|
||||
@@ -1498,6 +1505,9 @@ class SlackChannelConfig(Base):
|
||||
channel_config: Mapped[ChannelConfig] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False
|
||||
)
|
||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
|
||||
@@ -415,6 +415,9 @@ def upsert_prompt(
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: This operation cannot update persona configuration options that
|
||||
# are core to the persona, such as its display priority and
|
||||
# whether or not the assistant is a built-in / default assistant
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -446,12 +449,6 @@ def upsert_persona(
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
"""
|
||||
NOTE: This operation cannot update persona configuration options that
|
||||
are core to the persona, such as its display priority and
|
||||
whether or not the assistant is a built-in / default assistant
|
||||
"""
|
||||
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
else:
|
||||
@@ -489,8 +486,6 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
# Built-in personas can only be updated through YAML configuration.
|
||||
# This ensures that core system personas are not modified unintentionally.
|
||||
if persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
@@ -499,9 +494,6 @@ def upsert_persona(
|
||||
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
persona.name = name
|
||||
persona.description = description
|
||||
persona.num_chunks = num_chunks
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
@@ -82,6 +83,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -113,6 +115,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id=slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=response_type,
|
||||
standard_answer_categories=existing_standard_answer_categories,
|
||||
enable_auto_filters=enable_auto_filters,
|
||||
)
|
||||
@@ -127,6 +130,7 @@ def update_slack_channel_config(
|
||||
slack_channel_config_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -166,6 +170,7 @@ def update_slack_channel_config(
|
||||
# will encounter `violates foreign key constraint` errors
|
||||
slack_channel_config.persona_id = persona_id
|
||||
slack_channel_config.channel_config = channel_config
|
||||
slack_channel_config.response_type = response_type
|
||||
slack_channel_config.standard_answer_categories = list(
|
||||
existing_standard_answer_categories
|
||||
)
|
||||
|
||||
@@ -59,12 +59,6 @@ class FileStore(ABC):
|
||||
Contents of the file and metadata dict
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def read_file_record(self, file_name: str) -> PGFileStore:
|
||||
"""
|
||||
Read the file record by the name
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_name: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -18,12 +18,18 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
@@ -208,23 +214,18 @@ class Answer:
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
|
||||
@@ -9,6 +9,9 @@ from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -67,29 +70,28 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
# No longer in use, remove later
|
||||
# class QuotesResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# is_json_prompt: bool = True,
|
||||
# ):
|
||||
# self.quotes_processor = QuotesProcessor(
|
||||
# context_docs=context_docs,
|
||||
# is_json_prompt=is_json_prompt,
|
||||
# )
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | None,
|
||||
# previous_response_items: list[BaseMessage],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# if response_item is None:
|
||||
# yield from self.quotes_processor.process_token(None)
|
||||
# return
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item.content, str) else ""
|
||||
# )
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# yield from self.quotes_processor.process_token(content)
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
|
||||
@@ -67,9 +67,9 @@ class CitationProcessor:
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
@@ -77,15 +77,13 @@ class CitationProcessor:
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
|
||||
result = ""
|
||||
result = "" # Initialize result here
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
numerical_value = int(citation.group(1))
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
@@ -133,6 +131,14 @@ class CitationProcessor:
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
|
||||
@@ -143,7 +149,6 @@ class CitationProcessor:
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# THIS IS NO LONGER IN USE
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
@@ -6,10 +5,11 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
@@ -26,20 +26,6 @@ logger = setup_logger()
|
||||
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -14,11 +13,8 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
@@ -36,15 +32,11 @@ def get_main_llm_from_tuple(
|
||||
|
||||
|
||||
def get_llms_for_persona(
|
||||
persona: Persona | PersonaOverrideConfig | None,
|
||||
persona: Persona,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
if persona is None:
|
||||
logger.warning("No persona provided, using default LLMs")
|
||||
return get_default_llms()
|
||||
|
||||
model_provider_override = llm_override.model_provider if llm_override else None
|
||||
model_version_override = llm_override.model_version if llm_override else None
|
||||
temperature_override = llm_override.temperature if llm_override else None
|
||||
@@ -79,7 +71,6 @@ def get_llms_for_persona(
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
temperature=temperature_override,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
@@ -137,13 +128,11 @@ def get_llm(
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
temperature: float | None = None,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import create_danswer_oauth_router
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
@@ -91,7 +92,6 @@ from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.setup import setup_multitenant_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -206,7 +206,7 @@ def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
|
||||
if isinstance(exc, BasicAuthenticationError):
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
|
||||
logger.warning(f"Authentication failed: {str(exc)}")
|
||||
logger.error(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code >= 400:
|
||||
error_msg = f"{str(exc)}\n"
|
||||
|
||||
0
backend/danswer/one_shot_answer/__init__.py
Normal file
0
backend/danswer/one_shot_answer/__init__.py
Normal file
456
backend/danswer/one_shot_answer/answer_question.py
Normal file
456
backend/danswer/one_shot_answer/answer_question.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import reorganize_citations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import DocumentRelevance
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import RelevanceAnalysis
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import RerankMetricsContainer
|
||||
from danswer.context.search.models import RetrievalMetricsContainer
|
||||
from danswer.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.context.search.utils import dedupe_documents
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.one_shot_answer.qa_utils import slackify_message_thread
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
AnswerObjectIterator = Iterator[
|
||||
QueryRephrase
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| DanswerContexts
|
||||
| StreamingError
|
||||
| ChatMessageDetail
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| DocumentRelevance
|
||||
]
|
||||
|
||||
|
||||
def stream_answer_objects(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
# These need to be passed in because in Web UI one shot flow,
|
||||
# we can have much more document as there is no history.
|
||||
# For Slack flow, we need to save more tokens for the thread context
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
"""Streams in order:
|
||||
1. [always] Retrieved documents, stops flow if nothing is found
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed DanswerAnswerPiece and DanswerQuotes at the end
|
||||
or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
user_id = user.id if user is not None else None
|
||||
query_msg = query_req.messages[-1]
|
||||
history = query_req.messages[:-1]
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # One shot queries don't need naming as it's never displayed
|
||||
user_id=user_id,
|
||||
persona_id=query_req.persona_id,
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)}
|
||||
)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
|
||||
if query_req.persona_config is not None:
|
||||
temporary_persona = fetch_ee_implementation_or_noop(
|
||||
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
|
||||
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
try:
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona, long_term_logger=long_term_logger
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}"
|
||||
)
|
||||
if "No LLM provider" in str(e):
|
||||
raise ValueError(
|
||||
"Please configure a Generative AI model to use this feature."
|
||||
) from e
|
||||
raise ValueError(
|
||||
"Failed to initialize the AI model. Please check your configuration and try again."
|
||||
) from e
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Create a chat session which will just store the root message, the query, and the AI response
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
history_str = combine_message_thread(
|
||||
messages=history,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
rephrased_query = query_req.query_override or thread_based_query_rephrase(
|
||||
user_query=query_msg.message,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
# Given back ahead of the documents for latency reasons
|
||||
# In chat flow it's given back along with the documents
|
||||
yield QueryRephrase(rephrased_query=rephrased_query)
|
||||
|
||||
prompt = None
|
||||
if query_req.prompt_id is not None:
|
||||
# NOTE: let the user access any prompt as long as the Persona is shared
|
||||
# with them
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
user_message_str = query_msg.message
|
||||
# For this endpoint, we only save one user message to the chat session
|
||||
# However, for slackbot, we want to include the history of the entire thread
|
||||
if danswerbot_flow:
|
||||
# Right now, we only support bringing over citations and search docs
|
||||
# from the last message in the thread, not the entire thread
|
||||
# in the future, we may want to retrieve the entire thread
|
||||
user_message_str = slackify_message_thread(query_req.messages)
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=user_message_str,
|
||||
token_count=len(llm_tokenizer.encode(user_message_str)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type
|
||||
),
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(
|
||||
get_llms_for_persona(persona=persona, long_term_logger=long_term_logger)
|
||||
),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
skip_explicit_tool_calling=True,
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
# won't be any FileChatDisplay responses since that tool is never passed in
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
# (likely fine that it comes after the initial creation of the search docs)
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
top_docs = chunks_or_sections_to_search_docs(
|
||||
search_response_summary.top_sections
|
||||
)
|
||||
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
deduped_docs = top_docs
|
||||
if query_req.retrieval_options.dedupe_docs:
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
initial_response = QADocsResponse(
|
||||
rephrased_query=rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=search_response_summary.predicted_flow,
|
||||
predicted_search=search_response_summary.predicted_search,
|
||||
applied_source_filters=search_response_summary.final_filters.source_type,
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
document_based_response = {}
|
||||
|
||||
if packet.response is not None:
|
||||
for evaluation in packet.response:
|
||||
document_based_response[
|
||||
evaluation.document_id
|
||||
] = RelevanceAnalysis(
|
||||
relevant=evaluation.relevant, content=evaluation.content
|
||||
)
|
||||
|
||||
evaluation_response = DocumentRelevance(
|
||||
relevance_summaries=document_based_response
|
||||
)
|
||||
if reference_db_search_docs is not None:
|
||||
update_search_docs_table_with_relevance(
|
||||
db_session=db_session,
|
||||
reference_db_search_docs=reference_db_search_docs,
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=answer.llm_answer,
|
||||
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
yield msg_detail_response
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as session:
|
||||
objects = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
def get_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
answer_generation_timeout: int = QA_TIMEOUT,
|
||||
enable_reflexion: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> OneShotQAResponse:
|
||||
"""Collects the streamed one shot answer responses into a single object"""
|
||||
qa_response = OneShotQAResponse()
|
||||
|
||||
results = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
timeout=answer_generation_timeout,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
|
||||
answer = ""
|
||||
for packet in results:
|
||||
if isinstance(packet, QueryRephrase):
|
||||
qa_response.rephrase = packet.rephrased_query
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
qa_response.docs = packet
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, DanswerQuotes):
|
||||
qa_response.quotes = packet
|
||||
elif isinstance(packet, CitationInfo):
|
||||
if qa_response.citations:
|
||||
qa_response.citations.append(packet)
|
||||
else:
|
||||
qa_response.citations = [packet]
|
||||
elif isinstance(packet, DanswerContexts):
|
||||
qa_response.contexts = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
qa_response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
qa_response.chat_message_id = packet.message_id
|
||||
|
||||
if answer:
|
||||
qa_response.answer = answer
|
||||
|
||||
if enable_reflexion:
|
||||
# Because follow up messages are explicitly tagged, we don't need to verify the answer
|
||||
if len(query_req.messages) == 1:
|
||||
first_query = query_req.messages[0].message
|
||||
qa_response.answer_valid = get_answer_validity(first_query, answer)
|
||||
else:
|
||||
qa_response.answer_valid = True
|
||||
|
||||
if use_citations and qa_response.answer and qa_response.citations:
|
||||
# Reorganize citation nums to be in the same order as the answer
|
||||
qa_response.answer, qa_response.citations = reorganize_citations(
|
||||
qa_response.answer, qa_response.citations
|
||||
)
|
||||
|
||||
return qa_response
|
||||
114
backend/danswer/one_shot_answer/models.py
Normal file
114
backend/danswer/one_shot_answer/models.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
|
||||
|
||||
class QueryRephrase(BaseModel):
|
||||
rephrased_query: str
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
|
||||
chain_of_thought: bool = False
|
||||
return_contexts: bool = False
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
# will also disable Thread-based Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
raise ValueError(
|
||||
"If chain_of_thought is True, prompt_id must be None"
|
||||
"The chain of thought prompt is only for question "
|
||||
"answering and does not accept customizing."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
rephrase: str | None = None
|
||||
quotes: DanswerQuotes | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
chat_message_id: int | None = None
|
||||
contexts: DanswerContexts | None = None
|
||||
81
backend/danswer/one_shot_answer/qa_utils.py
Normal file
81
backend/danswer/one_shot_answer/qa_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
|
||||
"""Mock streaming by generating the passed in model output, character by character"""
|
||||
for token in model_out:
|
||||
yield token
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def slackify_message(message: ThreadMessage) -> str:
|
||||
if message.role != MessageType.USER:
|
||||
return message.message
|
||||
|
||||
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"DanswerBot said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
@@ -106,7 +105,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.models import Document as DbDocument
|
||||
@@ -115,7 +114,7 @@ class RedisConnectorDelete:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -12,7 +12,6 @@ from danswer.access.models import DocExternalAccess
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncPayload(BaseModel):
|
||||
@@ -133,8 +132,6 @@ class RedisConnectorPermissionSync:
|
||||
lock: RedisLock | None,
|
||||
new_permissions: list[DocExternalAccess],
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
async_results = []
|
||||
@@ -152,13 +149,11 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
"update_external_document_permissions_task",
|
||||
kwargs=dict(
|
||||
tenant_id=self.tenant_id,
|
||||
serialized_doc_external_access=doc_perm.to_dict(),
|
||||
source_string=source_string,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
|
||||
|
||||
@@ -135,7 +134,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
@@ -77,7 +76,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -90,7 +89,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -20,7 +20,6 @@ from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.connectors.google_utils.google_auth import (
|
||||
@@ -868,7 +867,7 @@ def connector_run_once(
|
||||
|
||||
# run the beat task to pick up the triggers immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"check_for_indexing",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -200,7 +199,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
|
||||
# run the beat task to pick up this deletion from the db immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"check_for_connector_deletion_task",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +15,7 @@ from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
|
||||
from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBot as SlackAppModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig as SlackChannelConfigModel
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
@@ -148,12 +148,6 @@ class SlackBotTokens(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
# TODO No longer in use, remove later
|
||||
class SlackBotResponseType(str, Enum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfigCreationRequest(BaseModel):
|
||||
slack_bot_id: int
|
||||
# currently, a persona is created for each Slack channel config
|
||||
@@ -203,6 +197,7 @@ class SlackChannelConfig(BaseModel):
|
||||
id: int
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
@@ -222,6 +217,7 @@ class SlackChannelConfig(BaseModel):
|
||||
else None
|
||||
),
|
||||
channel_config=slack_channel_config_model.channel_config,
|
||||
response_type=slack_channel_config_model.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
|
||||
@@ -118,6 +118,7 @@ def create_slack_channel_config(
|
||||
slack_bot_id=slack_channel_config_creation_request.slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_channel_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
@@ -181,6 +182,7 @@ def patch_slack_channel_config(
|
||||
slack_channel_config_id=slack_channel_config_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_channel_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from danswer.auth.noauth_user import fetch_no_auth_user
|
||||
from danswer.auth.noauth_user import set_no_auth_user_preferences
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserStatus
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
@@ -33,6 +34,7 @@ from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.api_key import is_api_key_email_address
|
||||
@@ -59,11 +61,9 @@ from danswer.server.manage.models import UserRoleUpdateRequest
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.server.utils import send_user_email_invite
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.configs.app_configs import SUPER_USERS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -109,7 +109,6 @@ def process_run_in_background(
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=search_tool_retrieval_details, # Adjust as needed
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
regenerate=None,
|
||||
llm_override=None,
|
||||
|
||||
@@ -707,18 +707,14 @@ def upload_files_for_chat(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/file/{file_id:path}")
|
||||
@router.get("/file/{file_id}")
|
||||
def fetch_chat_file(
|
||||
file_id: str,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> Response:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
return StreamingResponse(file_io, media_type=media_type)
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
|
||||
|
||||
@@ -5,14 +5,12 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import RetrievalDocs
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.context.search.models import Tag
|
||||
@@ -89,8 +87,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None
|
||||
retrieval_options: RetrievalDetails | None
|
||||
# Useable via the APIs but not recommended for most flows
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
@@ -106,10 +102,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# allow user to specify an alternate assistnat
|
||||
alternate_assistant_id: int | None = None
|
||||
|
||||
# This takes the priority over the prompt_override
|
||||
# This won't be a type that's passed in directly from the API
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
@@ -153,7 +145,7 @@ class RenameChatSessionResponse(BaseModel):
|
||||
|
||||
class ChatSessionDetails(BaseModel):
|
||||
id: UUID
|
||||
name: str | None
|
||||
name: str
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
@@ -206,14 +198,14 @@ class ChatMessageDetail(BaseModel):
|
||||
|
||||
class SearchSessionDetailResponse(BaseModel):
|
||||
search_session_id: UUID
|
||||
description: str | None
|
||||
description: str
|
||||
documents: list[SearchDoc]
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
description: str
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
@@ -28,6 +32,8 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.tag import find_tags
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.one_shot_answer.answer_question import stream_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.server.query_and_chat.models import AdminSearchRequest
|
||||
from danswer.server.query_and_chat.models import AdminSearchResponse
|
||||
from danswer.server.query_and_chat.models import ChatSessionDetails
|
||||
@@ -35,6 +41,7 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
|
||||
from danswer.server.query_and_chat.models import SearchSessionDetailResponse
|
||||
from danswer.server.query_and_chat.models import SourceTag
|
||||
from danswer.server.query_and_chat.models import TagResponse
|
||||
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -133,7 +140,7 @@ def get_user_search_sessions(
|
||||
|
||||
try:
|
||||
search_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session
|
||||
user_id=user_id, deleted=False, db_session=db_session, only_one_shot=True
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
@@ -222,3 +229,29 @@ def get_search_session(
|
||||
],
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@basic_router.post("/stream-answer-with-quote")
|
||||
def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User = Depends(current_limited_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
query = query_request.messages[0].message
|
||||
|
||||
logger.notice(f"Received query for one shot answer with quotes: {query}")
|
||||
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
for packet in stream_search_answer(
|
||||
query_req=query_request,
|
||||
user=user,
|
||||
max_document_tokens=None,
|
||||
max_history_tokens=0,
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
except Exception as e:
|
||||
logger.exception("Error in search answer streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/json")
|
||||
|
||||
@@ -6,9 +6,6 @@ from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
@@ -17,11 +14,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
@@ -103,14 +102,11 @@ class SearchToolConfig(BaseModel):
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
chunks_above: int = 0
|
||||
chunks_below: int = 0
|
||||
full_doc: bool = False
|
||||
latest_query_files: list[InMemoryChatFile] | None = None
|
||||
# Use with care, should only be used for DanswerBot in channels with multiple users
|
||||
bypass_acl: bool = False
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
|
||||
@@ -174,8 +170,6 @@ def construct_tools(
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
|
||||
updated_at=datetime.now(),
|
||||
link=result.link,
|
||||
source_links={0: result.link},
|
||||
match_highlights=[],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
@@ -104,7 +103,6 @@ class SearchTool(Tool):
|
||||
chunks_below: int | None = None,
|
||||
full_doc: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
) -> None:
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
@@ -120,9 +118,6 @@ class SearchTool(Tool):
|
||||
self.bypass_acl = bypass_acl
|
||||
self.db_session = db_session
|
||||
|
||||
# Only used via API
|
||||
self.rerank_settings = rerank_settings
|
||||
|
||||
self.chunks_above = (
|
||||
chunks_above
|
||||
if chunks_above is not None
|
||||
@@ -297,7 +292,6 @@ class SearchTool(Tool):
|
||||
self.retrieval_options.offset if self.retrieval_options else None
|
||||
),
|
||||
limit=self.retrieval_options.limit if self.retrieval_options else None,
|
||||
rerank_settings=self.rerank_settings,
|
||||
chunks_above=self.chunks_above,
|
||||
chunks_below=self.chunks_below,
|
||||
full_doc=self.full_doc,
|
||||
|
||||
@@ -1,72 +1,23 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from ee.danswer.configs.app_configs import SUPER_USERS
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.server.seeding import get_seed_config
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_public_key() -> str | None:
|
||||
if JWT_PUBLIC_KEY_URL is None:
|
||||
logger.error("JWT_PUBLIC_KEY_URL is not set")
|
||||
return None
|
||||
|
||||
response = requests.get(JWT_PUBLIC_KEY_URL)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
|
||||
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
|
||||
try:
|
||||
public_key_pem = get_public_key()
|
||||
if public_key_pem is None:
|
||||
logger.error("Failed to retrieve public key")
|
||||
return None
|
||||
|
||||
payload = jwt_decode(
|
||||
token,
|
||||
public_key_pem,
|
||||
algorithms=["RS256"],
|
||||
audience=None,
|
||||
)
|
||||
email = payload.get("email")
|
||||
if email:
|
||||
result = await async_db_session.execute(
|
||||
select(User).where(func.lower(User.email) == func.lower(email))
|
||||
)
|
||||
return result.scalars().first()
|
||||
except InvalidTokenError:
|
||||
logger.error("Invalid JWT token")
|
||||
get_public_key.cache_clear()
|
||||
except PyJWTError as e:
|
||||
logger.error(f"JWT decoding error: {str(e)}")
|
||||
get_public_key.cache_clear()
|
||||
return None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
@@ -87,13 +38,6 @@ async def optional_user_(
|
||||
)
|
||||
user = saml_account.user if saml_account else None
|
||||
|
||||
# If user is still None, check for JWT in Authorization header
|
||||
if user is None and JWT_PUBLIC_KEY_URL is not None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
user = await verify_jwt_token(token, async_db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -4,17 +4,16 @@ from typing import Any
|
||||
from danswer.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": DanswerCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": "autogenerate_usage_report_task",
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": DanswerCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"task": "check_ttl_management_task",
|
||||
"schedule": timedelta(hours=1),
|
||||
},
|
||||
]
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.process_message import ChatPacketStream
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.utils.timing import log_function_time
|
||||
from ee.danswer.server.query_and_chat.models import OneShotQAResponse
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_answer_api(
|
||||
packets: ChatPacketStream,
|
||||
) -> OneShotQAResponse:
|
||||
response = OneShotQAResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
# Extraneous, provided for backwards compatibility
|
||||
response.rephrase = packet.rephrased_query
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
elif isinstance(packet, DanswerContexts):
|
||||
response.contexts = packet
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
# Applicable for OIDC Auth
|
||||
@@ -20,11 +19,3 @@ STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
|
||||
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
|
||||
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
@@ -155,6 +155,7 @@ def _handle_standard_answers(
|
||||
else 0,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
one_shot=True,
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
|
||||
@@ -37,15 +37,10 @@ def get_cc_pairs_by_source(
|
||||
source_type: DocumentSource,
|
||||
only_sync: bool,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""
|
||||
Get all cc_pairs for a given source type (and optionally only sync)
|
||||
result is sorted by cc_pair id
|
||||
"""
|
||||
query = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.join(ConnectorCredentialPair.connector)
|
||||
.filter(Connector.source == source_type)
|
||||
.order_by(ConnectorCredentialPair.id)
|
||||
)
|
||||
|
||||
if only_sync:
|
||||
|
||||
@@ -55,10 +55,9 @@ def upsert_document_external_perms(
|
||||
doc_id: str,
|
||||
external_access: ExternalAccess,
|
||||
source_type: DocumentSource,
|
||||
) -> bool:
|
||||
) -> None:
|
||||
"""
|
||||
This sets the permissions for a document in postgres. Returns True if the
|
||||
a new document was created, False otherwise.
|
||||
This sets the permissions for a document in postgres.
|
||||
NOTE: this will replace any existing external access, it will not do a union
|
||||
"""
|
||||
document = db_session.scalars(
|
||||
@@ -86,7 +85,7 @@ def upsert_document_external_perms(
|
||||
)
|
||||
db_session.add(document)
|
||||
db_session.commit()
|
||||
return True
|
||||
return
|
||||
|
||||
# If the document exists, we need to check if the external access has changed
|
||||
if (
|
||||
@@ -99,5 +98,3 @@ def upsert_document_external_perms(
|
||||
document.is_public = external_access.is_public
|
||||
document.last_modified = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
return False
|
||||
|
||||
@@ -33,7 +33,12 @@ def get_empty_chat_messages_entries__paginated(
|
||||
|
||||
message_skeletons: list[ChatMessageSkeleton] = []
|
||||
for chat_session in chat_sessions:
|
||||
flow_type = FlowType.SLACK if chat_session.danswerbot_flow else FlowType.CHAT
|
||||
if chat_session.one_shot:
|
||||
flow_type = FlowType.SEARCH
|
||||
elif chat_session.danswerbot_flow:
|
||||
flow_type = FlowType.SLACK
|
||||
else:
|
||||
flow_type = FlowType.CHAT
|
||||
|
||||
for message in chat_session.messages:
|
||||
# Only count user messages
|
||||
|
||||
@@ -48,11 +48,6 @@ GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
|
||||
}
|
||||
|
||||
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
|
||||
DocumentSource.CONFLUENCE,
|
||||
}
|
||||
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
@@ -62,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 30 minutes
|
||||
# Polling is not supported so we fetch all group permissions every 5 minutes
|
||||
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||
DocumentSource.CONFLUENCE: 30 * 60,
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.chat_utils import combine_message_thread
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
@@ -17,8 +16,8 @@ from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.process_message import ChatPacketStream
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.context.search.models import OptionalSearchSetting
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
@@ -30,6 +29,7 @@ from danswer.db.models import User
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
@@ -171,8 +171,6 @@ def handle_simplified_chat_message(
|
||||
prompt_id=None,
|
||||
search_doc_ids=chat_message_req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=chat_message_req.query_override,
|
||||
# Currently only applies to search flow not chat
|
||||
chunks_above=0,
|
||||
@@ -234,6 +232,7 @@ def handle_send_message_simple_with_history(
|
||||
description="handle_send_message_simple_with_history",
|
||||
user_id=user_id,
|
||||
persona_id=req.persona_id,
|
||||
one_shot=False,
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona=chat_session.persona)
|
||||
@@ -246,7 +245,7 @@ def handle_send_message_simple_with_history(
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name, model_provider=llm.config.model_provider
|
||||
)
|
||||
max_history_tokens = int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
|
||||
max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
@@ -294,8 +293,6 @@ def handle_send_message_simple_with_history(
|
||||
prompt_id=req.prompt_id,
|
||||
search_doc_ids=req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=rephrased_query,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
|
||||
@@ -2,13 +2,7 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import SearchType
|
||||
@@ -16,6 +10,7 @@ from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
@@ -101,48 +96,3 @@ class ChatBasicResponse(BaseModel):
|
||||
# TODO: deprecate both of these
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
# Easier APIs to work with for developers
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
return_contexts: bool = False
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
# will also disable Thread-based Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "OneShotQARequest":
|
||||
if self.persona_override_config is None and self.persona_id is None:
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
elif self.persona_override_config is not None and (
|
||||
self.persona_id is not None or self.prompt_id is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"If persona_override_config is set, persona_id and prompt_id cannot be set"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
rephrase: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
contexts: DanswerContexts | None = None
|
||||
|
||||
@@ -1,47 +1,38 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.chat_utils import combine_message_thread
|
||||
from danswer.chat.chat_utils import prepare_chat_message_request
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.process_message import ChatPacketStream
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.context.search.models import SavedSearchDocWithContent
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
from danswer.context.search.utils import dedupe_documents
|
||||
from danswer.context.search.utils import drop_llm_indices
|
||||
from danswer.context.search.utils import relevant_sections_to_indices
|
||||
from danswer.db.chat import get_prompt_by_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.chat.process_message import gather_stream_for_answer_api
|
||||
from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
oneoff_standard_answers,
|
||||
)
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from ee.danswer.server.query_and_chat.models import OneShotQARequest
|
||||
from ee.danswer.server.query_and_chat.models import OneShotQAResponse
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -134,115 +125,58 @@ def handle_search_request(
|
||||
return DocumentSearchResponse(top_documents=deduped_docs, llm_indices=llm_indices)
|
||||
|
||||
|
||||
def get_answer_stream(
|
||||
query_request: OneShotQARequest,
|
||||
@basic_router.post("/answer-with-quote")
|
||||
def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatPacketStream:
|
||||
) -> OneShotQAResponse:
|
||||
query = query_request.messages[0].message
|
||||
logger.notice(f"Received query for Answer API: {query}")
|
||||
logger.notice(f"Received query for one shot answer API with quotes: {query}")
|
||||
|
||||
if (
|
||||
query_request.persona_override_config is None
|
||||
and query_request.persona_id is None
|
||||
):
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
prompt = None
|
||||
if query_request.prompt_id is not None:
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_request.prompt_id,
|
||||
user=user,
|
||||
if query_request.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=query_request.persona_config,
|
||||
user=user,
|
||||
)
|
||||
persona = new_persona
|
||||
|
||||
persona_info: Persona | PersonaOverrideConfig | None = None
|
||||
if query_request.persona_override_config is not None:
|
||||
persona_info = query_request.persona_override_config
|
||||
elif query_request.persona_id is not None:
|
||||
persona_info = get_persona_by_id(
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
llm = get_main_llm_from_tuple(get_llms_for_persona(persona_info))
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
llm = get_main_llm_from_tuple(
|
||||
get_default_llms() if not persona else get_llms_for_persona(persona)
|
||||
)
|
||||
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name, model_provider=llm.config.model_provider
|
||||
)
|
||||
max_history_tokens = int(input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE)
|
||||
max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)
|
||||
|
||||
combined_message = combine_message_thread(
|
||||
messages=query_request.messages,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
|
||||
# Also creates a new chat session
|
||||
request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
answer_details = get_search_answer(
|
||||
query_req=query_request,
|
||||
user=user,
|
||||
persona_id=query_request.persona_id,
|
||||
persona_override_config=query_request.persona_override_config,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=None,
|
||||
retrieval_details=query_request.retrieval_options,
|
||||
rerank_settings=query_request.rerank_settings,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
include_contexts=query_request.return_contexts,
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
@basic_router.post("/answer-with-citation")
|
||||
def get_answer_with_citation(
|
||||
request: OneShotQARequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> OneShotQAResponse:
|
||||
try:
|
||||
packets = get_answer_stream(request, user, db_session)
|
||||
answer = gather_stream_for_answer_api(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="An internal server error occurred")
|
||||
|
||||
|
||||
@basic_router.post("/stream-answer-with-citation")
|
||||
def stream_answer_with_citation(
|
||||
request: OneShotQARequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> StreamingResponse:
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
for packet in get_answer_stream(request, user, db_session):
|
||||
serialized = get_json_line(packet.model_dump())
|
||||
yield serialized
|
||||
except Exception as e:
|
||||
logger.exception("Error in answer streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/json")
|
||||
return answer_details
|
||||
|
||||
|
||||
@basic_router.get("/standard-answer")
|
||||
|
||||
85
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
85
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.one_shot_answer.models import PersonaConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
@@ -179,7 +179,13 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
|
||||
|
||||
def determine_flow_type(chat_session: ChatSession) -> SessionType:
|
||||
return SessionType.SLACK if chat_session.danswerbot_flow else SessionType.CHAT
|
||||
return (
|
||||
SessionType.SLACK
|
||||
if chat_session.danswerbot_flow
|
||||
else SessionType.SEARCH
|
||||
if chat_session.one_shot
|
||||
else SessionType.CHAT
|
||||
)
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history_minimal(
|
||||
|
||||
@@ -9,6 +9,7 @@ from danswer.auth.schemas import UserStatus
|
||||
|
||||
class FlowType(str, Enum):
|
||||
CHAT = "chat"
|
||||
SEARCH = "search"
|
||||
SLACK = "slack"
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ cohere==5.6.1
|
||||
fastapi==0.109.2
|
||||
google-cloud-aiplatform==1.58.0
|
||||
numpy==1.26.4
|
||||
openai==1.55.3
|
||||
openai==1.52.2
|
||||
pydantic==2.8.2
|
||||
retry==0.9.2
|
||||
safetensors==0.4.2
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Basic Usage:
|
||||
|
||||
python scripts/chat_loadtest.py --api-key <api-key> --url <danswer-url>/api
|
||||
|
||||
to run from the container itself, copy this file in and run:
|
||||
|
||||
python chat_loadtest.py --api-key <api-key> --url localhost:8080
|
||||
|
||||
For more options, checkout the bottom of the file.
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import statistics
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from logging import getLogger
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMetrics:
|
||||
session_id: UUID
|
||||
total_time: float
|
||||
first_doc_time: float
|
||||
first_answer_time: float
|
||||
tokens_per_second: float
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatLoadTester:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str | None,
|
||||
num_concurrent: int,
|
||||
messages_per_session: int,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
self.num_concurrent = num_concurrent
|
||||
self.messages_per_session = messages_per_session
|
||||
self.metrics: list[ChatMetrics] = []
|
||||
|
||||
async def create_chat_session(self, session: aiohttp.ClientSession) -> str:
|
||||
"""Create a new chat session"""
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/create-chat-session",
|
||||
headers=self.headers,
|
||||
json={"persona_id": 0, "description": "Load Test"},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
return data["chat_session_id"]
|
||||
|
||||
async def process_stream(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Process the SSE stream from the chat response"""
|
||||
async for chunk in response.content:
|
||||
chunk_str = chunk.decode()
|
||||
yield chunk_str
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
session: aiohttp.ClientSession,
|
||||
chat_session_id: str,
|
||||
message: str,
|
||||
parent_message_id: int | None = None,
|
||||
) -> ChatMetrics:
|
||||
"""Send a message and measure performance metrics"""
|
||||
start_time = time.time()
|
||||
first_doc_time = None
|
||||
first_answer_time = None
|
||||
token_count = 0
|
||||
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/send-message",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"chat_session_id": chat_session_id,
|
||||
"message": message,
|
||||
"parent_message_id": parent_message_id,
|
||||
"prompt_id": None,
|
||||
"retrieval_options": {
|
||||
"run_search": "always",
|
||||
"real_time": True,
|
||||
},
|
||||
"file_descriptors": [],
|
||||
"search_doc_ids": [],
|
||||
},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for chunk in self.process_stream(response):
|
||||
if "tool_name" in chunk and "run_search" in chunk:
|
||||
if first_doc_time is None:
|
||||
first_doc_time = time.time() - start_time
|
||||
|
||||
if "answer_piece" in chunk:
|
||||
if first_answer_time is None:
|
||||
first_answer_time = time.time() - start_time
|
||||
token_count += 1
|
||||
|
||||
total_time = time.time() - start_time
|
||||
tokens_per_second = token_count / total_time if total_time > 0 else 0
|
||||
|
||||
return ChatMetrics(
|
||||
session_id=UUID(chat_session_id),
|
||||
total_time=total_time,
|
||||
first_doc_time=first_doc_time or 0,
|
||||
first_answer_time=first_answer_time or 0,
|
||||
tokens_per_second=tokens_per_second,
|
||||
total_tokens=token_count,
|
||||
)
|
||||
|
||||
async def run_chat_session(self) -> None:
|
||||
"""Run a complete chat session with multiple messages"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
chat_session_id = await self.create_chat_session(session)
|
||||
messages = [
|
||||
"Tell me about the key features of the product",
|
||||
"How does the search functionality work?",
|
||||
"What are the deployment options?",
|
||||
"Can you explain the security features?",
|
||||
"What integrations are available?",
|
||||
]
|
||||
|
||||
parent_message_id = None
|
||||
for i in range(self.messages_per_session):
|
||||
message = messages[i % len(messages)]
|
||||
metrics = await self.send_message(
|
||||
session, chat_session_id, message, parent_message_id
|
||||
)
|
||||
self.metrics.append(metrics)
|
||||
parent_message_id = metrics.total_tokens # Simplified for example
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat session: {e}")
|
||||
|
||||
async def run_load_test(self) -> None:
|
||||
"""Run multiple concurrent chat sessions"""
|
||||
start_time = time.time()
|
||||
tasks = [self.run_chat_session() for _ in range(self.num_concurrent)]
|
||||
await asyncio.gather(*tasks)
|
||||
total_time = time.time() - start_time
|
||||
|
||||
self.print_results(total_time)
|
||||
|
||||
def print_results(self, total_time: float) -> None:
|
||||
"""Print load test results and metrics"""
|
||||
logger.info("\n=== Load Test Results ===")
|
||||
logger.info(f"Total Time: {total_time:.2f} seconds")
|
||||
logger.info(f"Concurrent Sessions: {self.num_concurrent}")
|
||||
logger.info(f"Messages per Session: {self.messages_per_session}")
|
||||
logger.info(f"Total Messages: {len(self.metrics)}")
|
||||
|
||||
if self.metrics:
|
||||
avg_response_time = statistics.mean(m.total_time for m in self.metrics)
|
||||
avg_first_doc = statistics.mean(m.first_doc_time for m in self.metrics)
|
||||
avg_first_answer = statistics.mean(
|
||||
m.first_answer_time for m in self.metrics
|
||||
)
|
||||
avg_tokens_per_sec = statistics.mean(
|
||||
m.tokens_per_second for m in self.metrics
|
||||
)
|
||||
|
||||
logger.info(f"\nAverage Response Time: {avg_response_time:.2f} seconds")
|
||||
logger.info(f"Average Time to Documents: {avg_first_doc:.2f} seconds")
|
||||
logger.info(f"Average Time to First Answer: {avg_first_answer:.2f} seconds")
|
||||
logger.info(f"Average Tokens/Second: {avg_tokens_per_sec:.2f}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Chat Load Testing Tool")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="http://localhost:3000/api",
|
||||
help="Danswer URL",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
help="Danswer Basic/Admin Level API key",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concurrent",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of concurrent chat sessions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--messages",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of messages per chat session",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_tester = ChatLoadTester(
|
||||
base_url=args.url,
|
||||
api_key=args.api_key,
|
||||
num_concurrent=args.concurrent,
|
||||
messages_per_session=args.messages,
|
||||
)
|
||||
|
||||
asyncio.run(load_tester.run_load_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,79 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# makes it so `PYTHONPATH=.` is not required when running this script
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from danswer.db.engine import get_session_context_manager # noqa: E402
|
||||
from danswer.db.document import delete_documents_complete__no_commit # noqa: E402
|
||||
from danswer.db.search_settings import get_current_search_settings # noqa: E402
|
||||
from danswer.document_index.vespa.index import VespaIndex # noqa: E402
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import ( # noqa: E402
|
||||
RetryDocumentIndex,
|
||||
)
|
||||
|
||||
|
||||
def _get_orphaned_document_ids(db_session: Session) -> list[str]:
|
||||
"""Get document IDs that don't have any entries in document_by_connector_credential_pair"""
|
||||
query = text(
|
||||
"""
|
||||
SELECT d.id
|
||||
FROM document d
|
||||
LEFT JOIN document_by_connector_credential_pair dbcc ON d.id = dbcc.id
|
||||
WHERE dbcc.id IS NULL
|
||||
"""
|
||||
)
|
||||
orphaned_ids = [doc_id[0] for doc_id in db_session.execute(query)]
|
||||
print(f"Found {len(orphaned_ids)} orphaned documents")
|
||||
return orphaned_ids
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
# Get orphaned document IDs
|
||||
orphaned_ids = _get_orphaned_document_ids(db_session)
|
||||
if not orphaned_ids:
|
||||
print("No orphaned documents found")
|
||||
return
|
||||
|
||||
# Setup Vespa index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_name = search_settings.index_name
|
||||
vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
|
||||
retry_index = RetryDocumentIndex(vespa_index)
|
||||
|
||||
# Delete chunks from Vespa first
|
||||
print("Deleting orphaned document chunks from Vespa")
|
||||
successfully_vespa_deleted_doc_ids = []
|
||||
for doc_id in orphaned_ids:
|
||||
try:
|
||||
chunks_deleted = retry_index.delete_single(doc_id)
|
||||
successfully_vespa_deleted_doc_ids.append(doc_id)
|
||||
if chunks_deleted > 0:
|
||||
print(f"Deleted {chunks_deleted} chunks for document {doc_id}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error deleting document {doc_id} in Vespa and will not delete from Postgres: {e}"
|
||||
)
|
||||
|
||||
# Delete documents from Postgres
|
||||
print("Deleting orphaned documents from Postgres")
|
||||
try:
|
||||
delete_documents_complete__no_commit(
|
||||
db_session, successfully_vespa_deleted_doc_ids
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
print(f"Error deleting documents from Postgres: {e}")
|
||||
|
||||
print(
|
||||
f"Successfully cleaned up {len(successfully_vespa_deleted_doc_ids)} orphaned documents"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -163,92 +163,47 @@ SUPPORTED_EMBEDDING_MODELS = [
|
||||
dim=1024,
|
||||
index_name="danswer_chunk_cohere_embed_english_v3_0",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="cohere/embed-english-v3.0",
|
||||
dim=1024,
|
||||
index_name="danswer_chunk_embed_english_v3_0",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="cohere/embed-english-light-v3.0",
|
||||
dim=384,
|
||||
index_name="danswer_chunk_cohere_embed_english_light_v3_0",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="cohere/embed-english-light-v3.0",
|
||||
dim=384,
|
||||
index_name="danswer_chunk_embed_english_light_v3_0",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="openai/text-embedding-3-large",
|
||||
dim=3072,
|
||||
index_name="danswer_chunk_openai_text_embedding_3_large",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="openai/text-embedding-3-large",
|
||||
dim=3072,
|
||||
index_name="danswer_chunk_text_embedding_3_large",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="openai/text-embedding-3-small",
|
||||
dim=1536,
|
||||
index_name="danswer_chunk_openai_text_embedding_3_small",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="openai/text-embedding-3-small",
|
||||
dim=1536,
|
||||
index_name="danswer_chunk_text_embedding_3_small",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="google/text-embedding-004",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_text_embedding_004",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="google/text-embedding-004",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_text_embedding_004",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="google/textembedding-gecko@003",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_textembedding_gecko_003",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="google/textembedding-gecko@003",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_textembedding_gecko_003",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="voyage/voyage-large-2-instruct",
|
||||
dim=1024,
|
||||
index_name="danswer_chunk_voyage_large_2_instruct",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="voyage/voyage-large-2-instruct",
|
||||
dim=1024,
|
||||
index_name="danswer_chunk_large_2_instruct",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="voyage/voyage-light-2-instruct",
|
||||
dim=384,
|
||||
index_name="danswer_chunk_voyage_light_2_instruct",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="voyage/voyage-light-2-instruct",
|
||||
dim=384,
|
||||
index_name="danswer_chunk_light_2_instruct",
|
||||
),
|
||||
# Self-hosted models
|
||||
SupportedEmbeddingModel(
|
||||
name="nomic-ai/nomic-embed-text-v1",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_nomic_ai_nomic_embed_text_v1",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="nomic-ai/nomic-embed-text-v1",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_nomic_embed_text_v1",
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="intfloat/e5-base-v2",
|
||||
dim=768,
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.slab.connector import SlabConnector
|
||||
|
||||
|
||||
def load_test_data(file_name: str = "test_slab_data.json") -> dict[str, str]:
|
||||
current_dir = Path(__file__).parent
|
||||
with open(current_dir / file_name, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def slab_connector() -> SlabConnector:
|
||||
connector = SlabConnector(
|
||||
base_url="https://onyx-test.slab.com/",
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"slab_bot_token": os.environ["SLAB_BOT_TOKEN"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"Need a test account with a slab subscription to run this test."
|
||||
"Trial only lasts 14 days."
|
||||
)
|
||||
)
|
||||
def test_slab_connector_basic(slab_connector: SlabConnector) -> None:
|
||||
all_docs: list[Document] = []
|
||||
target_test_doc_id = "jcp6cohu"
|
||||
target_test_doc: Document | None = None
|
||||
for doc_batch in slab_connector.poll_source(0, time.time()):
|
||||
for doc in doc_batch:
|
||||
all_docs.append(doc)
|
||||
if doc.id == target_test_doc_id:
|
||||
target_test_doc = doc
|
||||
|
||||
assert len(all_docs) == 6
|
||||
assert target_test_doc is not None
|
||||
|
||||
desired_test_data = load_test_data()
|
||||
assert (
|
||||
target_test_doc.semantic_identifier == desired_test_data["semantic_identifier"]
|
||||
)
|
||||
assert target_test_doc.source == DocumentSource.SLAB
|
||||
assert target_test_doc.metadata == {}
|
||||
assert target_test_doc.primary_owners is None
|
||||
assert target_test_doc.secondary_owners is None
|
||||
assert target_test_doc.title is None
|
||||
assert target_test_doc.from_ingestion_api is False
|
||||
assert target_test_doc.additional_info is None
|
||||
|
||||
assert len(target_test_doc.sections) == 1
|
||||
section = target_test_doc.sections[0]
|
||||
# Need to replace the weird apostrophe with a normal one
|
||||
assert section.text.replace("\u2019", "'") == desired_test_data["section_text"]
|
||||
assert section.link == desired_test_data["link"]
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"Need a test account with a slab subscription to run this test."
|
||||
"Trial only lasts 14 days."
|
||||
)
|
||||
)
|
||||
def test_slab_connector_slim(slab_connector: SlabConnector) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
for doc_batch in slab_connector.load_from_state():
|
||||
all_full_doc_ids.update([doc.id for doc in doc_batch])
|
||||
|
||||
# Get all doc IDs from the slim connector
|
||||
all_slim_doc_ids = set()
|
||||
for slim_doc_batch in slab_connector.retrieve_all_slim_documents():
|
||||
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
|
||||
|
||||
# The set of full doc IDs should be always be a subset of the slim doc IDs
|
||||
assert all_full_doc_ids.issubset(all_slim_doc_ids)
|
||||
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"section_text": "Learn about Posts\nWelcome\nThis is a post, where you can edit, share, and collaborate in real time with your team. We'd love to show you how it works!\nReading and editing\nClick the mode button to toggle between read and edit modes. You can only make changes to a post when editing.\nOrganize your posts\nWhen in edit mode, you can add topics to a post, which will keep it organized for the right 👀 to see.\nSmart mentions\nMentions are references to users, posts, topics and third party tools that show details on hover. Paste in a link for automatic conversion.\nLook back in time\nYou are ready to begin writing. You can always bring back this tour in the help menu.\nGreat job!\nYou are ready to begin writing. You can always bring back this tour in the help menu.\n\n",
|
||||
"link": "https://onyx-test.slab.com/posts/learn-about-posts-jcp6cohu",
|
||||
"semantic_identifier": "Learn about Posts"
|
||||
}
|
||||
@@ -8,6 +8,8 @@ from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -66,7 +68,6 @@ class ChatSessionManager:
|
||||
prompt_id=prompt_id,
|
||||
search_doc_ids=search_doc_ids or [],
|
||||
retrieval_options=retrieval_options,
|
||||
rerank_settings=None, # Can be added if needed
|
||||
query_override=query_override,
|
||||
regenerate=regenerate,
|
||||
llm_override=llm_override,
|
||||
@@ -86,6 +87,30 @@ class ChatSessionManager:
|
||||
|
||||
return ChatSessionManager.analyze_response(response)
|
||||
|
||||
@staticmethod
|
||||
def get_answer_with_quote(
|
||||
persona_id: int,
|
||||
message: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> StreamedResponse:
|
||||
direct_qa_request = DirectQARequest(
|
||||
messages=[ThreadMessage(message=message)],
|
||||
prompt_id=None,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/query/stream-answer-with-quote",
|
||||
json=direct_qa_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
stream=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ChatSessionManager.analyze_response(response)
|
||||
|
||||
@staticmethod
|
||||
def analyze_response(response: Response) -> StreamedResponse:
|
||||
response_data = [
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import mimetypes
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class FileManager:
|
||||
@staticmethod
|
||||
def upload_files(
|
||||
files: List[Tuple[str, IO]],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> Tuple[List[FileDescriptor], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
files_param = []
|
||||
for filename, file_obj in files:
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
files_param.append(("files", (filename, file_obj, mime_type)))
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/file",
|
||||
files=files_param,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
return (
|
||||
cast(List[FileDescriptor], []),
|
||||
f"Failed to upload files - {response.json().get('detail', 'Unknown error')}",
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
return response_json.get("files", cast(List[FileDescriptor], [])), ""
|
||||
|
||||
@staticmethod
|
||||
def fetch_uploaded_file(
|
||||
file_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bytes:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/file/{file_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
@@ -44,7 +44,6 @@ def test_persona_category_management(reset: None) -> None:
|
||||
category=updated_persona_category,
|
||||
user_performing_action=regular_user,
|
||||
)
|
||||
assert exc_info.value.response is not None
|
||||
assert exc_info.value.response.status_code == 403
|
||||
|
||||
assert PersonaCategoryManager.verify(
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_send_message_simple_with_history(reset: None) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
|
||||
|
||||
response = ChatSessionManager.get_answer_with_quote(
|
||||
persona_id=test_chat_session.persona_id,
|
||||
message="Hello, this is a test.",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.tool_name is not None
|
||||
), "Tool name should be specified (always search)"
|
||||
assert (
|
||||
response.relevance_summaries is not None
|
||||
), "Relevance summaries should be present for all search streams"
|
||||
assert len(response.full_message) > 0, "Response message should not be empty"
|
||||
@@ -1,16 +1,16 @@
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import OptionalSearchSetting
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from ee.danswer.server.query_and_chat.models import OneShotQARequest
|
||||
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
|
||||
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
@@ -37,7 +37,7 @@ def get_answer_from_query(
|
||||
|
||||
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
|
||||
|
||||
new_message_request = OneShotQARequest(
|
||||
new_message_request = DirectQARequest(
|
||||
messages=messages,
|
||||
prompt_id=0,
|
||||
persona_id=0,
|
||||
@@ -47,11 +47,12 @@ def get_answer_from_query(
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
),
|
||||
chain_of_thought=False,
|
||||
return_contexts=True,
|
||||
skip_gen_ai_answer_generation=only_retrieve_docs,
|
||||
)
|
||||
|
||||
url = _api_url_builder(env_name, "/query/answer-with-citation/")
|
||||
url = _api_url_builder(env_name, "/query/answer-with-quote/")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@@ -64,7 +64,6 @@ def mock_search_results() -> list[LlmDoc]:
|
||||
updated_at=datetime(2023, 1, 1),
|
||||
link="https://example.com/doc1",
|
||||
source_links={0: "https://example.com/doc1"},
|
||||
match_highlights=[],
|
||||
),
|
||||
LlmDoc(
|
||||
content="Search result 2",
|
||||
@@ -76,7 +75,6 @@ def mock_search_results() -> list[LlmDoc]:
|
||||
updated_at=datetime(2023, 1, 2),
|
||||
link="https://example.com/doc2",
|
||||
source_links={0: "https://example.com/doc2"},
|
||||
match_highlights=[],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ mock_docs = [
|
||||
updated_at=datetime.now(),
|
||||
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
|
||||
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
|
||||
match_highlights=[],
|
||||
)
|
||||
for id in range(10)
|
||||
]
|
||||
@@ -386,16 +385,6 @@ def process_text(
|
||||
"Here is some text[[1]](https://0.com). Some other text",
|
||||
["doc_0"],
|
||||
),
|
||||
# ['To', ' set', ' up', ' D', 'answer', ',', ' if', ' you', ' are', ' running', ' it', ' yourself', ' and',
|
||||
# ' need', ' access', ' to', ' certain', ' features', ' like', ' auto', '-sync', 'ing', ' document',
|
||||
# '-level', ' access', ' permissions', ',', ' you', ' should', ' reach', ' out', ' to', ' the', ' D',
|
||||
# 'answer', ' team', ' to', ' receive', ' access', ' [[', '4', ']].', '']
|
||||
(
|
||||
"Unique tokens with double brackets and a single token that ends the citation and has characters after it.",
|
||||
["... to receive access", " [[", "1", "]].", ""],
|
||||
"... to receive access [[1]](https://0.com).",
|
||||
["doc_0"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_citation_extraction(
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
|
||||
mock_docs = [
|
||||
LlmDoc(
|
||||
document_id=f"doc_{int(id/2)}",
|
||||
content="Document is a doc",
|
||||
blurb=f"Document #{id}",
|
||||
semantic_identifier=f"Doc {id}",
|
||||
source_type=DocumentSource.WEB,
|
||||
metadata={},
|
||||
updated_at=datetime.now(),
|
||||
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
|
||||
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
|
||||
)
|
||||
for id in range(10)
|
||||
]
|
||||
|
||||
|
||||
def _process_tokens(
|
||||
processor: QuotesProcessor, tokens: list[str]
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Process a list of tokens and return the answer and quotes.
|
||||
|
||||
Args:
|
||||
processor: QuotesProcessor instance
|
||||
tokens: List of tokens to process
|
||||
|
||||
Returns:
|
||||
Tuple of (answer_text, list_of_quotes)
|
||||
"""
|
||||
answer = ""
|
||||
quotes: list[str] = []
|
||||
|
||||
# need to add a None to the end to simulate the end of the stream
|
||||
for token in tokens + [None]:
|
||||
for output in processor.process_token(token):
|
||||
if isinstance(output, DanswerAnswerPiece):
|
||||
if output.answer_piece:
|
||||
answer += output.answer_piece
|
||||
elif isinstance(output, DanswerQuotes):
|
||||
quotes.extend(q.quote for q in output.quotes)
|
||||
|
||||
return answer, quotes
|
||||
|
||||
|
||||
def test_process_model_tokens_answer() -> None:
|
||||
tokens_with_quotes = [
|
||||
"{",
|
||||
"\n ",
|
||||
'"answer": "Yes',
|
||||
", Danswer allows",
|
||||
" customized prompts. This",
|
||||
" feature",
|
||||
" is currently being",
|
||||
" developed and implemente",
|
||||
"d to",
|
||||
" improve",
|
||||
" the accuracy",
|
||||
" of",
|
||||
" Language",
|
||||
" Models (",
|
||||
"LL",
|
||||
"Ms) for",
|
||||
" different",
|
||||
" companies",
|
||||
".",
|
||||
" The custom",
|
||||
"ized prompts feature",
|
||||
" woul",
|
||||
"d allow users to ad",
|
||||
"d person",
|
||||
"alized prom",
|
||||
"pts through",
|
||||
" an",
|
||||
" interface or",
|
||||
" metho",
|
||||
"d,",
|
||||
" which would then be used to",
|
||||
" train",
|
||||
" the LLM.",
|
||||
" This enhancement",
|
||||
" aims to make",
|
||||
" Danswer more",
|
||||
" adaptable to",
|
||||
" different",
|
||||
" business",
|
||||
" contexts",
|
||||
" by",
|
||||
" tail",
|
||||
"oring it",
|
||||
" to the specific language",
|
||||
" an",
|
||||
"d terminology",
|
||||
" used within",
|
||||
" a",
|
||||
" company.",
|
||||
" Additionally",
|
||||
",",
|
||||
" Danswer already",
|
||||
" supports creating",
|
||||
" custom AI",
|
||||
" Assistants with",
|
||||
" different",
|
||||
" prom",
|
||||
"pts and backing",
|
||||
" knowledge",
|
||||
" sets",
|
||||
",",
|
||||
" which",
|
||||
" is",
|
||||
" a form",
|
||||
" of prompt",
|
||||
" customization. However, it",
|
||||
"'s important to nLogging Details LiteLLM-Success Call: Noneote that some",
|
||||
" aspects",
|
||||
" of prompt",
|
||||
" customization,",
|
||||
" such as for",
|
||||
" Sl",
|
||||
"ack",
|
||||
"b",
|
||||
"ots, may",
|
||||
" still",
|
||||
" be in",
|
||||
" development or have",
|
||||
' limitations.",',
|
||||
'\n "quotes": [',
|
||||
'\n "We',
|
||||
" woul",
|
||||
"d like to ad",
|
||||
"d customized prompts for",
|
||||
" different",
|
||||
" companies to improve the accuracy of",
|
||||
" Language",
|
||||
" Model",
|
||||
" (LLM)",
|
||||
'.",\n "A',
|
||||
" new",
|
||||
" feature that",
|
||||
" allows users to add personalize",
|
||||
"d prompts.",
|
||||
" This would involve",
|
||||
" creating",
|
||||
" an interface or method for",
|
||||
" users to input",
|
||||
" their",
|
||||
" own",
|
||||
" prom",
|
||||
"pts,",
|
||||
" which would then be used to",
|
||||
' train the LLM.",',
|
||||
'\n "Create',
|
||||
" custom AI Assistants with",
|
||||
" different prompts and backing knowledge",
|
||||
' sets.",',
|
||||
'\n "This',
|
||||
" PR",
|
||||
" fixes",
|
||||
" https",
|
||||
"://github.com/dan",
|
||||
"swer-ai/dan",
|
||||
"swer/issues/1",
|
||||
"584",
|
||||
" by",
|
||||
" setting",
|
||||
" the system",
|
||||
" default",
|
||||
" prompt for",
|
||||
" sl",
|
||||
"ackbots const",
|
||||
"rained by",
|
||||
" ",
|
||||
"document sets",
|
||||
".",
|
||||
" It",
|
||||
" probably",
|
||||
" isn",
|
||||
"'t ideal",
|
||||
" -",
|
||||
" it",
|
||||
" might",
|
||||
" be pref",
|
||||
"erable to be",
|
||||
" able to select",
|
||||
" a prompt for",
|
||||
" the",
|
||||
" slackbot from",
|
||||
" the",
|
||||
" admin",
|
||||
" panel",
|
||||
" -",
|
||||
" but it sol",
|
||||
"ves the immediate problem",
|
||||
" of",
|
||||
" the slack",
|
||||
" listener",
|
||||
" cr",
|
||||
"ashing when",
|
||||
" configure",
|
||||
"d this",
|
||||
' way."\n ]',
|
||||
"\n}",
|
||||
"",
|
||||
]
|
||||
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens_with_quotes)
|
||||
|
||||
s_json = "".join(tokens_with_quotes)
|
||||
j = json.loads(s_json)
|
||||
expected_answer = j["answer"]
|
||||
assert expected_answer == answer
|
||||
# NOTE: no quotes, since the docs don't match the quotes
|
||||
assert len(quotes) == 0
|
||||
|
||||
|
||||
def test_simple_json_answer() -> None:
|
||||
tokens = [
|
||||
"```",
|
||||
"json",
|
||||
"\n",
|
||||
"{",
|
||||
'"answer": "This is a simple ',
|
||||
"answer.",
|
||||
'",\n"',
|
||||
'quotes": []',
|
||||
"\n}",
|
||||
"\n",
|
||||
"```",
|
||||
]
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens)
|
||||
|
||||
assert "This is a simple answer." == answer
|
||||
assert len(quotes) == 0
|
||||
|
||||
|
||||
def test_json_answer_with_quotes() -> None:
|
||||
tokens = [
|
||||
"```",
|
||||
"json",
|
||||
"\n",
|
||||
"{",
|
||||
'"answer": "This ',
|
||||
"is a ",
|
||||
"split ",
|
||||
"answer.",
|
||||
'",\n"',
|
||||
'quotes": []',
|
||||
"\n}",
|
||||
"\n",
|
||||
"```",
|
||||
]
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens)
|
||||
|
||||
assert "This is a split answer." == answer
|
||||
assert len(quotes) == 0
|
||||
|
||||
|
||||
def test_json_answer_with_quotes_one_chunk() -> None:
|
||||
tokens = ['```json\n{"answer": "z",\n"quotes": ["Document"]\n}\n```']
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens)
|
||||
|
||||
assert "z" == answer
|
||||
assert len(quotes) == 1
|
||||
assert quotes[0] == "Document"
|
||||
|
||||
|
||||
def test_json_answer_split_tokens() -> None:
|
||||
tokens = [
|
||||
"```",
|
||||
"json",
|
||||
"\n",
|
||||
"{",
|
||||
'\n"',
|
||||
'answer": "This ',
|
||||
"is a ",
|
||||
"split ",
|
||||
"answer.",
|
||||
'",\n"',
|
||||
'quotes": []',
|
||||
"\n}",
|
||||
"\n",
|
||||
"```",
|
||||
]
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens)
|
||||
|
||||
assert "This is a split answer." == answer
|
||||
assert len(quotes) == 0
|
||||
|
||||
|
||||
def test_lengthy_prefixed_json_with_quotes() -> None:
|
||||
tokens = [
|
||||
"This is my response in json\n\n",
|
||||
"```",
|
||||
"json",
|
||||
"\n",
|
||||
"{",
|
||||
'"answer": "This is a simple ',
|
||||
"answer.",
|
||||
'",\n"',
|
||||
'quotes": ["Document"]',
|
||||
"\n}",
|
||||
"\n",
|
||||
"```",
|
||||
]
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens)
|
||||
|
||||
assert "This is a simple answer." == answer
|
||||
assert len(quotes) == 1
|
||||
assert quotes[0] == "Document"
|
||||
|
||||
|
||||
def test_json_with_lengthy_prefix_and_quotes() -> None:
|
||||
tokens = [
|
||||
"*** Based on the provided documents, there does not appear to be any information ",
|
||||
"directly relevant to answering which documents are my favorite. ",
|
||||
"The documents seem to be focused on describing the Danswer product ",
|
||||
"and its features/use cases. Since I do not have personal preferences ",
|
||||
"for documents, I will provide a general response:\n\n",
|
||||
"```",
|
||||
"json",
|
||||
"\n",
|
||||
"{",
|
||||
'"answer": "This is a simple ',
|
||||
"answer.",
|
||||
'",\n"',
|
||||
'quotes": ["Document"]',
|
||||
"\n}",
|
||||
"\n",
|
||||
"```",
|
||||
]
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens)
|
||||
|
||||
assert "This is a simple answer." == answer
|
||||
assert len(quotes) == 1
|
||||
assert quotes[0] == "Document"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user