mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 16:25:45 +00:00
Compare commits
38 Commits
remove_see
...
log_error
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16e904a873 | ||
|
|
66f47d294c | ||
|
|
0a685bda7d | ||
|
|
23dc8b5dad | ||
|
|
cd5f2293ad | ||
|
|
6c2269e565 | ||
|
|
46315cddf1 | ||
|
|
5f28a1b0e4 | ||
|
|
9e9b7ed61d | ||
|
|
3fb2bfefec | ||
|
|
7c618c9d17 | ||
|
|
03e2789392 | ||
|
|
2783fa08a3 | ||
|
|
edeaee93a2 | ||
|
|
5385bae100 | ||
|
|
813445ab59 | ||
|
|
af814823c8 | ||
|
|
607f61eaeb | ||
|
|
de66f7adb2 | ||
|
|
3432d932d1 | ||
|
|
9bd0cb9eb5 | ||
|
|
f12eb4a5cf | ||
|
|
16863de0aa | ||
|
|
63d1eefee5 | ||
|
|
e338677896 | ||
|
|
7be80c4af9 | ||
|
|
7f1e4a02bf | ||
|
|
5be7d27285 | ||
|
|
fd84b7a768 | ||
|
|
36941ae663 | ||
|
|
212353ed4a | ||
|
|
eb8708f770 | ||
|
|
ac448956e9 | ||
|
|
634a0b9398 | ||
|
|
09d3e47c03 | ||
|
|
9c0cc94f15 | ||
|
|
07dfde2209 | ||
|
|
28e2b78b2e |
@@ -24,6 +24,8 @@ env:
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -73,6 +73,7 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,6 +8,7 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -35,7 +36,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add web ui option to slack config
|
||||
|
||||
Revision ID: 93560ba1b118
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-24 06:36:17.490612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93560ba1b118"
|
||||
down_revision = "6d562f86c78b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add show_continue_in_web_ui with default False to all existing channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
|
||||
WHERE NOT channel_config ? 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove show_continue_in_web_ui from all channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config - 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add auto scroll to user model
|
||||
|
||||
Revision ID: a8c2065484e6
|
||||
Revises: abe7378b8217
|
||||
Create Date: 2024-11-22 17:34:09.690295
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a8c2065484e6"
|
||||
down_revision = "abe7378b8217"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "auto_scroll")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add indexing trigger to cc_pair
|
||||
|
||||
Revision ID: abe7378b8217
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-26 19:09:53.481171
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abe7378b8217"
|
||||
down_revision = "93560ba1b118"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"indexing_trigger",
|
||||
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "indexing_trigger")
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except KvKeyNotFoundError:
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
return UserPreferences(
|
||||
chosen_assistants=None, default_model=None, auto_scroll=True
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
|
||||
@@ -11,6 +11,7 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -332,16 +333,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -11,6 +11,7 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -116,7 +117,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -227,7 +228,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock = worker.primary_worker_lock
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
|
||||
@@ -2,54 +2,55 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": "check_for_external_group_sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
|
||||
@@ -5,13 +5,13 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -29,7 +29,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -37,7 +37,7 @@ class TaskDependencyError(RuntimeError):
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
self.app, cc_pair_id, db_session, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
@@ -86,7 +86,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -17,6 +18,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -27,7 +29,7 @@ from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import doc_permission_sync_ctx
|
||||
@@ -81,7 +83,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_doc_permissions_sync",
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -138,7 +140,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -162,8 +164,8 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
app.send_task(
|
||||
"connector_permission_sync_generator_task",
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -174,8 +176,8 @@ def try_creating_permissions_sync_task(
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=None,
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
)
|
||||
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
@@ -190,7 +192,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_permission_sync_generator_task",
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -241,13 +243,17 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(f"No doc sync func found for {source_type}")
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=datetime.now(timezone.utc),
|
||||
)
|
||||
payload = redis_connector.permissions.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
@@ -281,7 +287,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="update_external_document_permissions_task",
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -16,6 +17,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -24,6 +26,9 @@ from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
@@ -49,7 +54,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
# skip external group sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
@@ -81,7 +86,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_external_group_sync",
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -107,7 +112,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
@@ -125,7 +130,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
@@ -156,8 +161,8 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
_ = app.send_task(
|
||||
"connector_external_group_sync_generator_task",
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -166,8 +171,13 @@ def try_creating_permissions_sync_task(
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
# set a basic fence to start
|
||||
redis_connector.external_group_sync.set_fence(True)
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
@@ -182,7 +192,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_external_group_sync_generator_task",
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -195,7 +205,7 @@ def connector_external_group_sync_generator_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
Permission sync task that handles external group syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
@@ -203,7 +213,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
@@ -228,9 +238,13 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(f"No external group sync func found for {source_type}")
|
||||
raise ValueError(
|
||||
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
@@ -249,7 +263,6 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
@@ -260,6 +273,6 @@ def connector_external_group_sync_generator_task(
|
||||
raise e
|
||||
finally:
|
||||
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
|
||||
redis_connector.external_group_sync.set_fence(False)
|
||||
redis_connector.external_group_sync.set_fence(None)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
@@ -23,13 +23,16 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
@@ -37,12 +40,13 @@ from danswer.db.index_attempt import delete_index_attempt
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_active_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
@@ -153,13 +157,13 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
locked = False
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -172,6 +176,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
locked = True
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
@@ -205,17 +211,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
@@ -231,22 +230,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
search_settings_primary=search_settings_primary,
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
# the indexing trigger is only checked and cleared with the primary search settings
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
|
||||
reindex = True
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing manual trigger detected: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"indexing_mode={cc_pair.indexing_trigger}"
|
||||
)
|
||||
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
False,
|
||||
reindex,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
@@ -256,7 +279,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
f"Connector indexing queued: "
|
||||
f"index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
@@ -281,7 +304,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -289,13 +311,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
return tasks_created
|
||||
|
||||
@@ -304,6 +327,7 @@ def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
@@ -368,6 +392,11 @@ def _should_index(
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
@@ -458,7 +487,7 @@ def try_creating_indexing_task(
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -495,8 +524,14 @@ def try_creating_indexing_task(
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
@@ -509,6 +544,10 @@ def connector_indexing_proxy_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
@@ -537,25 +576,72 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
finally:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
job.cancel()
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
f"Indexing watchdog - spawned task exceptioned: "
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
|
||||
@@ -13,12 +13,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="kombu_message_cleanup_task",
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -20,6 +20,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -75,7 +76,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_pruning",
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -184,7 +185,7 @@ def try_creating_prune_generator_task(
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -209,7 +210,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_pruning_generator_task",
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
|
||||
@@ -9,6 +9,7 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -31,7 +32,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
@@ -46,6 +47,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
@@ -58,7 +60,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
@@ -79,7 +81,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name="check_for_vespa_sync_task",
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -588,7 +590,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncData | None = (
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
@@ -596,9 +598,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
@@ -655,34 +655,42 @@ def monitor_ccpair_indexing_taskset(
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
result_state = result.state
|
||||
task_state = result.state
|
||||
if (
|
||||
result_state in READY_STATES
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"result_state={result_state} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
@@ -692,6 +700,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
@@ -699,7 +708,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -724,7 +733,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
@@ -810,7 +819,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = celery_app
|
||||
app: Celery = celery_app
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = fetch_versioned_implementation(
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
@@ -87,6 +88,10 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
@@ -208,9 +213,7 @@ def _run_indexing(
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_run_indexing: Connector stop signal detected"
|
||||
)
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
@@ -304,26 +307,16 @@ def _run_indexing(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
mark_attempt_canceled(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -335,6 +328,37 @@ def _run_indexing(
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
|
||||
@@ -31,6 +31,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
|
||||
@@ -605,6 +605,7 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
@@ -308,6 +308,22 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
# The current state of affairs:
|
||||
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
|
||||
# no API retrieves the user's timezone
|
||||
# All data is returned in UTC, so we can't derive the user's timezone from that
|
||||
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -493,10 +509,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
|
||||
@@ -259,6 +259,32 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
)
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
|
||||
@@ -11,11 +11,16 @@ Connectors come in 3 different flows:
|
||||
- Load Connector:
|
||||
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll connector:
|
||||
- Poll Connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Slim Connector:
|
||||
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
|
||||
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
|
||||
- This is used by our pruning job which removes old documents from the index.
|
||||
- The optional start and end datetimes can be ignored.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
- Currently not used by the background job, this exists for future design purposes.
|
||||
@@ -26,8 +31,14 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
|
||||
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
|
||||
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
|
||||
|
||||
For implementing a Slim Connector, refer to the comments in this PR:
|
||||
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
|
||||
|
||||
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
|
||||
|
||||
|
||||
#### Implementing the new Connector
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -51,7 +53,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 1000
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@@ -69,6 +71,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@@ -104,6 +107,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -204,12 +209,14 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
@@ -242,10 +249,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
@@ -301,5 +308,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
@@ -134,6 +134,32 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -306,6 +332,13 @@ def _validate_connector_configuration(
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
|
||||
@@ -12,12 +12,15 @@ from dateutil import parser
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -28,6 +31,8 @@ logger = setup_logger()
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
|
||||
@@ -158,21 +163,26 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
slab_bot_token: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self.slab_bot_token = slab_bot_token
|
||||
self._slab_bot_token: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.slab_bot_token = credentials["slab_bot_token"]
|
||||
self._slab_bot_token = credentials["slab_bot_token"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def slab_bot_token(self) -> str:
|
||||
if self._slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
return self._slab_bot_token
|
||||
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -227,3 +237,21 @@ class SlabConnector(LoadConnector, PollConnector):
|
||||
yield from self._iterate_posts(
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=post_id,
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -135,7 +135,6 @@ class SearchPipeline:
|
||||
|
||||
"""Retrieval and Postprocessing"""
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_chunks(self) -> list[InferenceChunk]:
|
||||
if self._retrieved_chunks is not None:
|
||||
return self._retrieved_chunks
|
||||
@@ -307,7 +306,6 @@ class SearchPipeline:
|
||||
return expanded_inference_sections
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def reranked_sections(self) -> list[InferenceSection]:
|
||||
"""Reranking is always done at the chunk level since section merging could create arbitrarily
|
||||
long sections which could be:
|
||||
@@ -333,7 +331,6 @@ class SearchPipeline:
|
||||
return self._reranked_sections
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def final_context_sections(self) -> list[InferenceSection]:
|
||||
if self._final_context_sections is not None:
|
||||
return self._final_context_sections
|
||||
@@ -342,7 +339,6 @@ class SearchPipeline:
|
||||
return self._final_context_sections
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def section_relevance(self) -> list[SectionRelevancePiece] | None:
|
||||
if self._section_relevance is not None:
|
||||
return self._section_relevance
|
||||
@@ -397,7 +393,6 @@ class SearchPipeline:
|
||||
return self._section_relevance
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def section_relevance_list(self) -> list[bool]:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=self.section_relevance,
|
||||
|
||||
@@ -42,7 +42,6 @@ def _log_top_section_links(search_flow: str, sections: list[InferenceSection]) -
|
||||
logger.debug(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
|
||||
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
|
||||
if not chunk.title or not chunk.content:
|
||||
@@ -245,7 +244,6 @@ def filter_sections(
|
||||
]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def search_postprocessing(
|
||||
search_query: SearchQuery,
|
||||
retrieved_sections: list[InferenceSection],
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import nltk # type:ignore
|
||||
@@ -86,7 +85,6 @@ def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
return keywords
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
@@ -258,13 +256,7 @@ def retrieve_chunks(
|
||||
(q_copy, document_index, db_session),
|
||||
)
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f"Parallel search execution took {end_time - start_time:.2f} seconds"
|
||||
)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
|
||||
@@ -8,7 +8,6 @@ from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.context.search.models import SavedSearchDocWithContent
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
T = TypeVar(
|
||||
@@ -89,7 +88,6 @@ def drop_llm_indices(
|
||||
return [i for i, val in enumerate(llm_bools) if val]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def inference_section_from_chunks(
|
||||
center_chunk: InferenceChunk,
|
||||
chunks: list[InferenceChunk],
|
||||
|
||||
@@ -18,20 +18,30 @@ from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.formatting import format_slack_message
|
||||
from danswer.danswerbot.slack.icons import source_to_github_img_link
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
|
||||
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
@@ -101,12 +111,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
|
||||
return chunks
|
||||
|
||||
|
||||
def clean_markdown_link_text(text: str) -> str:
|
||||
def _clean_markdown_link_text(text: str) -> str:
|
||||
# Remove any newlines within the text
|
||||
return text.replace("\n", " ").strip()
|
||||
|
||||
|
||||
def build_qa_feedback_block(
|
||||
def _build_qa_feedback_block(
|
||||
message_id: int, feedback_reminder_id: str | None = None
|
||||
) -> Block:
|
||||
return ActionsBlock(
|
||||
@@ -115,7 +125,6 @@ def build_qa_feedback_block(
|
||||
ButtonElement(
|
||||
action_id=LIKE_BLOCK_ACTION_ID,
|
||||
text="👍 Helpful",
|
||||
style="primary",
|
||||
value=feedback_reminder_id,
|
||||
),
|
||||
ButtonElement(
|
||||
@@ -155,7 +164,7 @@ def get_document_feedback_blocks() -> Block:
|
||||
)
|
||||
|
||||
|
||||
def build_doc_feedback_block(
|
||||
def _build_doc_feedback_block(
|
||||
message_id: int,
|
||||
document_id: str,
|
||||
document_rank: int,
|
||||
@@ -182,7 +191,7 @@ def get_restate_blocks(
|
||||
]
|
||||
|
||||
|
||||
def build_documents_blocks(
|
||||
def _build_documents_blocks(
|
||||
documents: list[SavedSearchDoc],
|
||||
message_id: int | None,
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
@@ -223,7 +232,7 @@ def build_documents_blocks(
|
||||
|
||||
feedback: ButtonElement | dict = {}
|
||||
if message_id is not None:
|
||||
feedback = build_doc_feedback_block(
|
||||
feedback = _build_doc_feedback_block(
|
||||
message_id=message_id,
|
||||
document_id=d.document_id,
|
||||
document_rank=rank,
|
||||
@@ -241,7 +250,7 @@ def build_documents_blocks(
|
||||
return section_blocks
|
||||
|
||||
|
||||
def build_sources_blocks(
|
||||
def _build_sources_blocks(
|
||||
cited_documents: list[tuple[int, SavedSearchDoc]],
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
) -> list[Block]:
|
||||
@@ -286,7 +295,7 @@ def build_sources_blocks(
|
||||
+ ([days_ago_str] if days_ago_str else [])
|
||||
)
|
||||
|
||||
document_title = clean_markdown_link_text(doc_sem_id)
|
||||
document_title = _clean_markdown_link_text(doc_sem_id)
|
||||
img_link = source_to_github_img_link(d.source_type)
|
||||
|
||||
section_blocks.append(
|
||||
@@ -317,7 +326,50 @@ def build_sources_blocks(
|
||||
return section_blocks
|
||||
|
||||
|
||||
def build_quotes_block(
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
if not priority_ordered_docs:
|
||||
return []
|
||||
|
||||
document_blocks = _build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
if document_blocks:
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
return document_blocks
|
||||
|
||||
|
||||
def _build_citations_blocks(
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = _build_sources_blocks(cited_documents=cited_docs)
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
@@ -359,58 +411,70 @@ def build_quotes_block(
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
quotes: list[DanswerQuote] | None,
|
||||
source_filters: list[DocumentSource] | None,
|
||||
time_cutoff: datetime | None,
|
||||
favor_recent: bool,
|
||||
def _build_qa_response_blocks(
|
||||
answer: OneShotQAResponse,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
skip_ai_feedback: bool = False,
|
||||
feedback_reminder_id: str | None = None,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
quotes = answer.quotes.quotes if answer.quotes else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if time_cutoff or favor_recent or source_filters:
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
or retrieval_info.recency_bias_multiplier > 1
|
||||
or retrieval_info.applied_source_filters
|
||||
):
|
||||
filter_text = "Filters: "
|
||||
if source_filters:
|
||||
sources_str = ", ".join([s.value for s in source_filters])
|
||||
if retrieval_info.applied_source_filters:
|
||||
sources_str = ", ".join(
|
||||
[s.value for s in retrieval_info.applied_source_filters]
|
||||
)
|
||||
filter_text += f"`Sources in [{sources_str}]`"
|
||||
if time_cutoff or favor_recent:
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
or retrieval_info.recency_bias_multiplier > 1
|
||||
):
|
||||
filter_text += " and "
|
||||
if time_cutoff is not None:
|
||||
time_str = time_cutoff.strftime("%b %d, %Y")
|
||||
if retrieval_info.applied_time_cutoff is not None:
|
||||
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
|
||||
filter_text += f"`Docs Updated >= {time_str}` "
|
||||
if favor_recent:
|
||||
if time_cutoff is not None:
|
||||
if retrieval_info.recency_bias_multiplier > 1:
|
||||
if retrieval_info.applied_time_cutoff is not None:
|
||||
filter_text += "+ "
|
||||
filter_text += "`Prioritize Recently Updated Docs`"
|
||||
|
||||
filter_block = SectionBlock(text=f"_{filter_text}_")
|
||||
|
||||
if not answer:
|
||||
if not formatted_answer:
|
||||
answer_blocks = [
|
||||
SectionBlock(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
]
|
||||
else:
|
||||
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
|
||||
answer_processed = decode_escapes(
|
||||
remove_slack_text_interactions(formatted_answer)
|
||||
)
|
||||
if process_message_for_citations:
|
||||
answer_processed = _process_citations_for_slack(answer_processed)
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = build_quotes_block(quotes)
|
||||
quotes_blocks = _build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `build_quotes_block()` did not give back any blocks
|
||||
# if no quotes OR `_build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
@@ -425,20 +489,37 @@ def build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if message_id is not None and not skip_ai_feedback:
|
||||
response_blocks.append(
|
||||
build_qa_feedback_block(
|
||||
message_id=message_id, feedback_reminder_id=feedback_reminder_id
|
||||
)
|
||||
)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
)
|
||||
return ActionsBlock(
|
||||
block_id=build_continue_in_web_ui_id(message_id),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
|
||||
text="Continue Chat in Danswer!",
|
||||
style="primary",
|
||||
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_id(message_id) if message_id is not None else None,
|
||||
elements=[
|
||||
@@ -483,3 +564,77 @@ def build_follow_up_resolved_blocks(
|
||||
]
|
||||
)
|
||||
return [text_block, button_block]
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
answer: OneShotQAResponse,
|
||||
persona: Persona | None,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
skip_ai_feedback: bool = False,
|
||||
) -> list[Block]:
|
||||
"""
|
||||
This function is a top level function that builds all the blocks for the Slack response.
|
||||
It also handles combining all the blocks together.
|
||||
"""
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
web_follow_up_block = []
|
||||
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
|
||||
web_follow_up_block.append(
|
||||
_build_continue_in_web_ui_block(
|
||||
tenant_id=tenant_id,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
follow_up_block = []
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
follow_up_block.append(
|
||||
_build_follow_up_block(message_id=answer.chat_message_id)
|
||||
)
|
||||
|
||||
ai_feedback_block = []
|
||||
if answer.chat_message_id is not None and not skip_ai_feedback:
|
||||
ai_feedback_block.append(
|
||||
_build_qa_feedback_block(
|
||||
message_id=answer.chat_message_id,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations:
|
||||
# if citations are enabled, only show cited documents
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block
|
||||
+ answer_blocks
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
return all_blocks
|
||||
|
||||
@@ -2,6 +2,7 @@ from enum import Enum
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
|
||||
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
|
||||
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
|
||||
|
||||
@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import decompose_action_id
|
||||
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_feedback_visibility
|
||||
from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
@@ -267,7 +267,7 @@ def handle_followup_button(
|
||||
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
|
||||
remaining = None
|
||||
if tag_names:
|
||||
tag_ids, remaining = fetch_user_ids_from_emails(
|
||||
tag_ids, remaining = fetch_slack_user_ids_from_emails(
|
||||
tag_names, client.web_client
|
||||
)
|
||||
if remaining:
|
||||
|
||||
@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
handle_standard_answers,
|
||||
)
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import slack_usage_report
|
||||
@@ -184,7 +184,7 @@ def handle_message(
|
||||
send_to: list[str] | None = None
|
||||
missing_users: list[str] | None = None
|
||||
if respond_member_group_list:
|
||||
send_to, missing_ids = fetch_user_ids_from_emails(
|
||||
send_to, missing_ids = fetch_slack_user_ids_from_emails(
|
||||
respond_member_group_list, client
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
@@ -25,12 +24,7 @@ from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
||||
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_sources_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.formatting import format_slack_message
|
||||
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
@@ -411,62 +405,16 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
message_id=answer.chat_message_id,
|
||||
answer=formatted_answer,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
# currently Personas don't support quotes
|
||||
# if citations are enabled, also don't use quotes
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
all_blocks = build_slack_response_blocks(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
persona=persona,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
|
||||
document_blocks = []
|
||||
citations_block = []
|
||||
# if citations are enabled, only show cited documents
|
||||
if use_citations:
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = build_sources_blocks(cited_documents=cited_docs)
|
||||
elif priority_ordered_docs:
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block + answer_blocks + citations_block + document_blocks
|
||||
)
|
||||
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
|
||||
@@ -3,9 +3,9 @@ import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
@@ -216,6 +216,13 @@ def build_feedback_id(
|
||||
return unique_prefix + ID_SEPARATOR + feedback_id
|
||||
|
||||
|
||||
def build_continue_in_web_ui_id(
|
||||
message_id: int,
|
||||
) -> str:
|
||||
unique_prefix = str(uuid.uuid4())[:10]
|
||||
return unique_prefix + ID_SEPARATOR + str(message_id)
|
||||
|
||||
|
||||
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
|
||||
"""Decompose into query_id, document_id, document_rank, see above function"""
|
||||
try:
|
||||
@@ -313,7 +320,7 @@ def get_channel_name_from_id(
|
||||
raise e
|
||||
|
||||
|
||||
def fetch_user_ids_from_emails(
|
||||
def fetch_slack_user_ids_from_emails(
|
||||
user_emails: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
user_ids: list[str] = []
|
||||
@@ -522,7 +529,7 @@ class SlackRateLimiter:
|
||||
self.last_reset_time = time.time()
|
||||
|
||||
def notify(
|
||||
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
|
||||
self, client: WebClient, channel: str, position: int, thread_ts: str | None
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
@@ -30,6 +31,7 @@ from danswer.db.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_best_persona_id_for_user
|
||||
from danswer.db.pg_file_store import delete_lobj_by_name
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
@@ -250,6 +252,50 @@ def create_chat_session(
|
||||
return chat_session
|
||||
|
||||
|
||||
def duplicate_chat_session_for_user_from_slack(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
chat_session_id: UUID,
|
||||
) -> ChatSession:
|
||||
"""
|
||||
This takes a chat session id for a session in Slack and:
|
||||
- Creates a new chat session in the DB
|
||||
- Tries to copy the persona from the original chat session
|
||||
(if it is available to the user clicking the button)
|
||||
- Sets the user to the given user (if provided)
|
||||
"""
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=None, # Ignore user permissions for this
|
||||
db_session=db_session,
|
||||
)
|
||||
if not chat_session:
|
||||
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
|
||||
|
||||
# This enforces permissions and sets a default
|
||||
new_persona_id = get_best_persona_id_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona_id=chat_session.persona_id,
|
||||
)
|
||||
|
||||
return create_chat_session(
|
||||
db_session=db_session,
|
||||
user_id=user.id if user else None,
|
||||
persona_id=new_persona_id,
|
||||
# Set this to empty string so the frontend will force a rename
|
||||
description="",
|
||||
llm_override=chat_session.llm_override,
|
||||
prompt_override=chat_session.prompt_override,
|
||||
# Chat sessions from Slack should put people in the chat UI, not the search
|
||||
one_shot=False,
|
||||
# Chat is in UI now so this is false
|
||||
danswerbot_flow=False,
|
||||
# Maybe we want this in the future to track if it was created from Slack
|
||||
slack_thread_id=None,
|
||||
)
|
||||
|
||||
|
||||
def update_chat_session(
|
||||
db_session: Session,
|
||||
user_id: UUID | None,
|
||||
@@ -336,6 +382,28 @@ def get_chat_message(
|
||||
return chat_message
|
||||
|
||||
|
||||
def get_chat_session_by_message_id(
|
||||
db_session: Session,
|
||||
message_id: int,
|
||||
) -> ChatSession:
|
||||
"""
|
||||
Should only be used for Slack
|
||||
Get the chat session associated with a specific message ID
|
||||
Note: this ignores permission checks.
|
||||
"""
|
||||
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_message = result.scalar_one_or_none()
|
||||
|
||||
if chat_message is None:
|
||||
raise ValueError(
|
||||
f"Unable to find chat session associated with message ID: {message_id}"
|
||||
)
|
||||
|
||||
return chat_message.chat_session
|
||||
|
||||
|
||||
def get_chat_messages_by_sessions(
|
||||
chat_session_ids: list[UUID],
|
||||
user_id: UUID | None,
|
||||
@@ -355,6 +423,44 @@ def get_chat_messages_by_sessions(
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def add_chats_to_session_from_slack_thread(
|
||||
db_session: Session,
|
||||
slack_chat_session_id: UUID,
|
||||
new_chat_session_id: UUID,
|
||||
) -> None:
|
||||
new_root_message = get_or_create_root_message(
|
||||
chat_session_id=new_chat_session_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
for chat_message in get_chat_messages_by_sessions(
|
||||
chat_session_ids=[slack_chat_session_id],
|
||||
user_id=None, # Ignore user permissions for this
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
):
|
||||
if chat_message.message_type == MessageType.SYSTEM:
|
||||
continue
|
||||
# Duplicate the message
|
||||
new_root_message = create_new_chat_message(
|
||||
db_session=db_session,
|
||||
chat_session_id=new_chat_session_id,
|
||||
parent_message=new_root_message,
|
||||
message=chat_message.message,
|
||||
files=chat_message.files,
|
||||
rephrased_query=chat_message.rephrased_query,
|
||||
error=chat_message.error,
|
||||
citations=chat_message.citations,
|
||||
reference_docs=chat_message.search_docs,
|
||||
tool_call=chat_message.tool_call,
|
||||
prompt_id=chat_message.prompt_id,
|
||||
token_count=chat_message.token_count,
|
||||
message_type=chat_message.message_type,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
|
||||
def get_search_docs_for_chat_message(
|
||||
chat_message_id: int, db_session: Session
|
||||
) -> list[SearchDoc]:
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -311,3 +312,25 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
|
||||
# If this changes, we need to update this function.
|
||||
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_ccpair_with_indexing_trigger(
|
||||
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
|
||||
) -> None:
|
||||
"""indexing_mode sets a field which will be picked up by a background task
|
||||
to trigger indexing. Set to None to disable the trigger."""
|
||||
try:
|
||||
cc_pair = db_session.execute(
|
||||
select(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
|
||||
|
||||
cc_pair.indexing_trigger = indexing_mode
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
def _relate_groups_to_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
user_group_ids: list[int],
|
||||
user_group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
if not user_group_ids:
|
||||
return
|
||||
|
||||
for group_id in user_group_ids:
|
||||
db_session.add(
|
||||
UserGroup__ConnectorCredentialPair(
|
||||
@@ -402,12 +405,11 @@ def add_credential_to_connector(
|
||||
db_session.flush() # make sure the association has an id
|
||||
db_session.refresh(association)
|
||||
|
||||
if groups and access_type != AccessType.SYNC:
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -19,6 +19,11 @@ class IndexingStatus(str, PyEnum):
|
||||
return self in terminal_states
|
||||
|
||||
|
||||
class IndexingMode(str, PyEnum):
|
||||
UPDATE = "update"
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
||||
# these may differ in the future, which is why we're okay with this duplication
|
||||
class DeletionStatus(str, PyEnum):
|
||||
NOT_STARTED = "not_started"
|
||||
|
||||
@@ -42,7 +42,7 @@ from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import AccessType, IndexingMode
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
@@ -126,6 +126,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
@@ -438,6 +439,10 @@ class ConnectorCredentialPair(Base):
|
||||
|
||||
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
|
||||
Enum(IndexingMode, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
connector: Mapped["Connector"] = relationship(
|
||||
"Connector", back_populates="credentials"
|
||||
)
|
||||
@@ -1480,6 +1485,7 @@ class ChannelConfig(TypedDict):
|
||||
# If None then no follow up
|
||||
# If empty list, follow up with no tags
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
show_continue_in_web_ui: NotRequired[bool] # defaults to False
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
|
||||
@@ -113,6 +113,31 @@ def fetch_persona_by_id(
|
||||
return persona
|
||||
|
||||
|
||||
def get_best_persona_id_for_user(
|
||||
db_session: Session, user: User | None, persona_id: int | None = None
|
||||
) -> int | None:
|
||||
if persona_id is not None:
|
||||
stmt = select(Persona).where(Persona.id == persona_id).distinct()
|
||||
stmt = _add_user_filters(
|
||||
stmt=stmt,
|
||||
user=user,
|
||||
# We don't want to filter by editable here, we just want to see if the
|
||||
# persona is usable by the user
|
||||
get_editable=False,
|
||||
)
|
||||
persona = db_session.scalars(stmt).one_or_none()
|
||||
if persona:
|
||||
return persona.id
|
||||
|
||||
# If the persona is not found, or the slack bot is using doc sets instead of personas,
|
||||
# we need to find the best persona for the user
|
||||
# This is the persona with the highest display priority that the user has access to
|
||||
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
|
||||
persona = db_session.scalars(stmt).one_or_none()
|
||||
return persona.id if persona else None
|
||||
|
||||
|
||||
def _get_persona_by_name(
|
||||
persona_name: str, user: User | None, db_session: Session
|
||||
) -> Persona | None:
|
||||
@@ -160,7 +185,7 @@ def create_update_persona(
|
||||
"persona_id": persona_id,
|
||||
"user": user,
|
||||
"db_session": db_session,
|
||||
**create_persona_request.dict(exclude={"users", "groups"}),
|
||||
**create_persona_request.model_dump(exclude={"users", "groups"}),
|
||||
}
|
||||
|
||||
persona = upsert_persona(**persona_data)
|
||||
@@ -390,9 +415,6 @@ def upsert_prompt(
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: This operation cannot update persona configuration options that
|
||||
# are core to the persona, such as its display priority and
|
||||
# whether or not the assistant is a built-in / default assistant
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -424,6 +446,12 @@ def upsert_persona(
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
"""
|
||||
NOTE: This operation cannot update persona configuration options that
|
||||
are core to the persona, such as its display priority and
|
||||
whether or not the assistant is a built-in / default assistant
|
||||
"""
|
||||
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
else:
|
||||
@@ -461,6 +489,8 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
# Built-in personas can only be updated through YAML configuration.
|
||||
# This ensures that core system personas are not modified unintentionally.
|
||||
if persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
@@ -469,6 +499,9 @@ def upsert_persona(
|
||||
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
persona.name = name
|
||||
persona.description = description
|
||||
persona.num_chunks = num_chunks
|
||||
@@ -733,6 +766,8 @@ def get_prompt_by_name(
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Prompt.user_id == user.id)
|
||||
|
||||
# Order by ID to ensure consistent result when multiple prompts exist
|
||||
stmt = stmt.order_by(Prompt.id).limit(1)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
|
||||
return latest_settings
|
||||
|
||||
|
||||
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
"""Returns active search settings. The first entry will always be the current search
|
||||
settings. If there are new search settings that are being migrated to, those will be
|
||||
the second entry."""
|
||||
search_settings_list: list[SearchSettings] = []
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings_list.append(primary_search_settings)
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings_list.append(secondary_search_settings)
|
||||
|
||||
return search_settings_list
|
||||
|
||||
|
||||
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
query = select(SearchSettings).order_by(SearchSettings.id.desc())
|
||||
result = db_session.execute(query)
|
||||
|
||||
@@ -48,7 +48,6 @@ from danswer.document_index.vespa_constants import TITLE
|
||||
from danswer.document_index.vespa_constants import YQL_BASE
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -147,7 +146,6 @@ def _vespa_hit_to_inference_chunk(
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_chunks_via_visit_api(
|
||||
chunk_request: VespaChunkRequest,
|
||||
index_name: str,
|
||||
@@ -234,7 +232,6 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
@log_function_time(print_only=True)
|
||||
def get_all_vespa_ids_for_document_id(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
@@ -251,7 +248,6 @@ def get_all_vespa_ids_for_document_id(
|
||||
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def parallel_visit_api_retrieval(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
@@ -266,12 +262,9 @@ def parallel_visit_api_retrieval(
|
||||
for chunk_request in chunk_requests
|
||||
]
|
||||
|
||||
start_time = datetime.now()
|
||||
parallel_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=True
|
||||
)
|
||||
duration = datetime.now() - start_time
|
||||
print(f"Parallel visit API retrieval took {duration.total_seconds():.2f} seconds")
|
||||
|
||||
# Any failures to retrieve would give a None, drop the Nones and empty lists
|
||||
vespa_chunk_sets = [res for res in parallel_results if res]
|
||||
@@ -289,11 +282,9 @@ def parallel_visit_api_retrieval(
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
@log_function_time(print_only=True)
|
||||
def query_vespa(
|
||||
query_params: Mapping[str, str | int | float]
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
print(f"query_params: {query_params}")
|
||||
if "query" in query_params and not cast(str, query_params["query"]).strip():
|
||||
raise ValueError("No/empty query received")
|
||||
|
||||
@@ -349,7 +340,6 @@ def query_vespa(
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_chunks_via_batch_search(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
@@ -384,7 +374,6 @@ def _get_chunks_via_batch_search(
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def batch_search_api_retrieval(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
|
||||
@@ -72,7 +72,6 @@ from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
@@ -661,7 +660,6 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
return total_chunks_deleted
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
@@ -683,7 +681,6 @@ class VespaIndex(DocumentIndex):
|
||||
get_large_chunks=get_large_chunks,
|
||||
)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@@ -295,7 +295,7 @@ def pptx_to_text(file: IO[Any]) -> str:
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any]) -> str:
|
||||
workbook = openpyxl.load_workbook(file)
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_string = "\n".join(
|
||||
|
||||
@@ -59,6 +59,12 @@ class FileStore(ABC):
|
||||
Contents of the file and metadata dict
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def read_file_record(self, file_name: str) -> PGFileStore:
|
||||
"""
|
||||
Read the file record by the name
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_name: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -21,7 +21,6 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -44,7 +43,6 @@ class ChunkRange(BaseModel):
|
||||
end: int
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
|
||||
"""
|
||||
This acts on a single document to merge the overlapping ranges of chunks
|
||||
@@ -302,7 +300,6 @@ def prune_sections(
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
|
||||
# Assuming there are no duplicates by this point
|
||||
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
|
||||
@@ -330,7 +327,6 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
|
||||
doc_order: dict[str, int] = {}
|
||||
|
||||
@@ -67,9 +67,9 @@ class CitationProcessor:
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
@@ -77,13 +77,15 @@ class CitationProcessor:
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
|
||||
result = "" # Initialize result here
|
||||
result = ""
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(citation.group(1))
|
||||
numerical_value = int(
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
@@ -131,14 +133,6 @@ class CitationProcessor:
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
|
||||
@@ -149,6 +143,7 @@ class CitationProcessor:
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
|
||||
@@ -26,7 +26,9 @@ from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
|
||||
from danswer.configs.model_configs import (
|
||||
DISABLE_LITELLM_STREAMING,
|
||||
)
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
|
||||
from danswer.llm.interfaces import LLM
|
||||
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant":
|
||||
# NOTE: if tool calls are present, then it's an assistant.
|
||||
# In Ollama, the role will be None for tool-calls
|
||||
elif role == "assistant" or tool_calls:
|
||||
if tool_calls:
|
||||
tool_call = tool_calls[0]
|
||||
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
|
||||
@@ -236,6 +240,7 @@ class DefaultMultiLLM(LLM):
|
||||
custom_config: dict[str, str] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict | None = LITELLM_EXTRA_BODY,
|
||||
model_kwargs: dict[str, Any] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
):
|
||||
self._timeout = timeout
|
||||
@@ -268,7 +273,7 @@ class DefaultMultiLLM(LLM):
|
||||
for k, v in custom_config.items():
|
||||
os.environ[k] = v
|
||||
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
model_kwargs = model_kwargs or {}
|
||||
if extra_headers:
|
||||
model_kwargs.update({"extra_headers": extra_headers})
|
||||
if extra_body:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
@@ -13,6 +16,15 @@ from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
|
||||
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
|
||||
TODO: allow model-specific values to be configured via the UI.
|
||||
"""
|
||||
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
|
||||
|
||||
|
||||
def get_main_llm_from_tuple(
|
||||
llms: tuple[LLM, LLM],
|
||||
) -> LLM:
|
||||
@@ -59,6 +71,7 @@ def get_llms_for_persona(
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
temperature=temperature_override,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
@@ -116,11 +129,13 @@ def get_llm(
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
temperature: float | None = None,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
@@ -132,5 +147,6 @@ def get_llm(
|
||||
temperature=temperature,
|
||||
custom_config=custom_config,
|
||||
extra_headers=build_llm_extra_headers(additional_headers),
|
||||
model_kwargs=_build_extra_model_kwargs(provider),
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
|
||||
return error_msg
|
||||
|
||||
|
||||
def get_model_map() -> dict:
|
||||
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
|
||||
|
||||
# NOTE: we could add additional models here in the future,
|
||||
# but for now there is no point. Ollama allows the user to
|
||||
# to specify their desired max context window, and it's
|
||||
# unlikely to be standard across users even for the same model
|
||||
# (it heavily depends on their hardware). For now, we'll just
|
||||
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
|
||||
# for model_name in [
|
||||
# "llama3.2",
|
||||
# "llama3.2:1b",
|
||||
# "llama3.2:3b",
|
||||
# "llama3.2:11b",
|
||||
# "llama3.2:90b",
|
||||
# ]:
|
||||
# starting_map[f"ollama/{model_name}"] = {
|
||||
# "max_tokens": 128000,
|
||||
# "max_input_tokens": 128000,
|
||||
# "max_output_tokens": 128000,
|
||||
# }
|
||||
|
||||
return starting_map
|
||||
|
||||
|
||||
def _strip_extra_provider_from_model_name(model_name: str) -> str:
|
||||
return model_name.split("/")[1] if "/" in model_name else model_name
|
||||
|
||||
|
||||
def _strip_colon_from_model_name(model_name: str) -> str:
|
||||
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
|
||||
|
||||
|
||||
def _find_model_obj(
|
||||
model_map: dict, provider: str, model_names: list[str | None]
|
||||
) -> dict | None:
|
||||
# Filter out None values and deduplicate model names
|
||||
filtered_model_names = [name for name in model_names if name]
|
||||
|
||||
# First try all model names with provider prefix
|
||||
for model_name in filtered_model_names:
|
||||
model_obj = model_map.get(f"{provider}/{model_name}")
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {provider}/{model_name}")
|
||||
return model_obj
|
||||
|
||||
# Then try all model names without provider prefix
|
||||
for model_name in filtered_model_names:
|
||||
model_obj = model_map.get(model_name)
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
return model_obj
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_max_tokens(
|
||||
model_map: dict,
|
||||
model_name: str,
|
||||
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
|
||||
return GEN_AI_MAX_TOKENS
|
||||
|
||||
try:
|
||||
model_obj = model_map.get(f"{model_provider}/{model_name}")
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_provider}/{model_name}")
|
||||
|
||||
if not model_obj:
|
||||
model_obj = model_map.get(model_name)
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
|
||||
if not model_obj:
|
||||
model_name_split = model_name.split("/")
|
||||
if len(model_name_split) > 1:
|
||||
model_obj = model_map.get(model_name_split[1])
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name_split[1]}")
|
||||
|
||||
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
|
||||
model_name
|
||||
)
|
||||
model_obj = _find_model_obj(
|
||||
model_map,
|
||||
model_provider,
|
||||
[
|
||||
model_name,
|
||||
# Remove leading extra provider. Usually for cases where user has a
|
||||
# customer model proxy which appends another prefix
|
||||
extra_provider_stripped_model_name,
|
||||
# remove :XXXX from the end, if present. Needed for ollama.
|
||||
_strip_colon_from_model_name(model_name),
|
||||
_strip_colon_from_model_name(extra_provider_stripped_model_name),
|
||||
],
|
||||
)
|
||||
if not model_obj:
|
||||
raise RuntimeError(
|
||||
f"No litellm entry found for {model_provider}/{model_name}"
|
||||
@@ -488,7 +545,7 @@ def get_max_input_tokens(
|
||||
# `model_cost` dict is a named public interface:
|
||||
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
|
||||
# model_map is litellm.model_cost
|
||||
litellm_model_map = litellm.model_cost
|
||||
litellm_model_map = get_model_map()
|
||||
|
||||
input_toks = (
|
||||
get_llm_max_tokens(
|
||||
|
||||
@@ -26,6 +26,7 @@ from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import create_danswer_oauth_router
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
@@ -44,6 +45,7 @@ from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.engine import warm_up_connections
|
||||
from danswer.server.api_key.api import router as api_key_router
|
||||
from danswer.server.auth_check import check_router_auth
|
||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||
@@ -280,6 +282,7 @@ def get_application() -> FastAPI:
|
||||
application, get_full_openai_assistants_api_router()
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
@@ -323,7 +326,7 @@ def get_application() -> FastAPI:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_oauth_router(
|
||||
create_danswer_oauth_router(
|
||||
oauth_client,
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
class ModelServerRateLimitError(Exception):
|
||||
"""
|
||||
Exception raised for rate limiting errors from the model server.
|
||||
"""
|
||||
@@ -6,6 +6,9 @@ from typing import Any
|
||||
|
||||
import requests
|
||||
from httpx import HTTPError
|
||||
from requests import JSONDecodeError
|
||||
from requests import RequestException
|
||||
from requests import Response
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
@@ -16,6 +19,9 @@ from danswer.configs.model_configs import (
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.natural_language_processing.exceptions import (
|
||||
ModelServerRateLimitError,
|
||||
)
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -99,28 +105,43 @@ class EmbeddingModel:
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
|
||||
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
|
||||
def _make_request() -> EmbedResponse:
|
||||
def _make_request() -> Response:
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.model_dump()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
try:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
except Exception:
|
||||
error_detail = response.text
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
# signify that this is a rate limit error
|
||||
if response.status_code == 429:
|
||||
raise ModelServerRateLimitError(response.text)
|
||||
|
||||
return EmbedResponse(**response.json())
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
|
||||
final_make_request_func = _make_request
|
||||
|
||||
# if the text type is a passage, add some default
|
||||
# retries + handling for rate limiting
|
||||
if embed_request.text_type == EmbedTextType.PASSAGE:
|
||||
return retry(tries=3, delay=5)(_make_request)()
|
||||
else:
|
||||
return _make_request()
|
||||
final_make_request_func = retry(
|
||||
tries=3,
|
||||
delay=5,
|
||||
exceptions=(RequestException, ValueError, JSONDecodeError),
|
||||
)(final_make_request_func)
|
||||
# use 10 second delay as per Azure suggestion
|
||||
final_make_request_func = retry(
|
||||
tries=10, delay=10, exceptions=ModelServerRateLimitError
|
||||
)(final_make_request_func)
|
||||
|
||||
try:
|
||||
response = final_make_request_func()
|
||||
return EmbedResponse(**response.json())
|
||||
except requests.HTTPError as e:
|
||||
try:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
except Exception:
|
||||
error_detail = response.text
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
|
||||
def _batch_encode_texts(
|
||||
self,
|
||||
|
||||
@@ -131,7 +131,7 @@ def _try_initialize_tokenizer(
|
||||
return tokenizer
|
||||
except Exception as hf_error:
|
||||
logger.warning(
|
||||
f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
|
||||
f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
|
||||
)
|
||||
|
||||
# If both initializations fail, return None
|
||||
|
||||
@@ -47,6 +47,7 @@ from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.one_shot_answer.qa_utils import slackify_message_thread
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
@@ -194,13 +195,22 @@ def stream_answer_objects(
|
||||
)
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
user_message_str = query_msg.message
|
||||
# For this endpoint, we only save one user message to the chat session
|
||||
# However, for slackbot, we want to include the history of the entire thread
|
||||
if danswerbot_flow:
|
||||
# Right now, we only support bringing over citations and search docs
|
||||
# from the last message in the thread, not the entire thread
|
||||
# in the future, we may want to retrieve the entire thread
|
||||
user_message_str = slackify_message_thread(query_req.messages)
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=query_msg.message,
|
||||
token_count=len(llm_tokenizer.encode(query_msg.message)),
|
||||
message=user_message_str,
|
||||
token_count=len(llm_tokenizer.encode(user_message_str)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
|
||||
@@ -51,3 +51,31 @@ def combine_message_thread(
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def slackify_message(message: ThreadMessage) -> str:
|
||||
if message.role != MessageType.USER:
|
||||
return message.message
|
||||
|
||||
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"DanswerBot said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
@@ -31,6 +34,44 @@ class RedisConnector:
|
||||
self.tenant_id, self.id, search_settings_id, self.redis
|
||||
)
|
||||
|
||||
def wait_for_indexing_termination(
|
||||
self,
|
||||
search_settings_list: list[SearchSettings],
|
||||
timeout: float = 15.0,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if all indexing for the given redis connector is finished within the given timeout.
|
||||
Returns False if the timeout is exceeded
|
||||
|
||||
This check does not guarantee that current indexings being terminated
|
||||
won't get restarted midflight
|
||||
"""
|
||||
|
||||
finished = False
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
while True:
|
||||
still_indexing = False
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = self.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
still_indexing = True
|
||||
break
|
||||
|
||||
if not still_indexing:
|
||||
finished = True
|
||||
break
|
||||
|
||||
now = time.monotonic()
|
||||
if now - start > timeout:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
return finished
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
@@ -105,7 +106,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.models import Document as DbDocument
|
||||
@@ -114,7 +115,7 @@ class RedisConnectorDelete:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -12,10 +12,12 @@ from danswer.access.models import DocExternalAccess
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncData(BaseModel):
|
||||
class RedisConnectorPermissionSyncPayload(BaseModel):
|
||||
started: datetime | None
|
||||
celery_task_id: str | None
|
||||
|
||||
|
||||
class RedisConnectorPermissionSync:
|
||||
@@ -78,14 +80,14 @@ class RedisConnectorPermissionSync:
|
||||
return False
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorPermissionSyncData | None:
|
||||
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorPermissionSyncData.model_validate_json(
|
||||
payload = RedisConnectorPermissionSyncPayload.model_validate_json(
|
||||
cast(str, fence_str)
|
||||
)
|
||||
|
||||
@@ -93,7 +95,7 @@ class RedisConnectorPermissionSync:
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
payload: RedisConnectorPermissionSyncData | None,
|
||||
payload: RedisConnectorPermissionSyncPayload | None,
|
||||
) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.fence_key)
|
||||
@@ -148,7 +150,7 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"update_external_document_permissions_task",
|
||||
DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
kwargs=dict(
|
||||
tenant_id=self.tenant_id,
|
||||
serialized_doc_external_access=doc_perm.to_dict(),
|
||||
@@ -162,6 +164,12 @@ class RedisConnectorPermissionSync:
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
self.redis.delete(self.taskset_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from pydantic import BaseModel
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class RedisConnectorExternalGroupSyncPayload(BaseModel):
|
||||
started: datetime | None
|
||||
celery_task_id: str | None
|
||||
|
||||
|
||||
class RedisConnectorExternalGroupSync:
|
||||
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
@@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync:
|
||||
|
||||
return False
|
||||
|
||||
def set_fence(self, value: bool) -> None:
|
||||
if not value:
|
||||
@property
|
||||
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
|
||||
cast(str, fence_str)
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
payload: RedisConnectorExternalGroupSyncPayload | None,
|
||||
) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
|
||||
@@ -29,6 +29,8 @@ class RedisConnectorIndex:
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
|
||||
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
@@ -51,6 +53,7 @@ class RedisConnectorIndex:
|
||||
self.generator_lock_key = (
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
|
||||
|
||||
@classmethod
|
||||
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
|
||||
@@ -92,6 +95,18 @@ class RedisConnectorIndex:
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
|
||||
def terminating(self, celery_task_id: str) -> bool:
|
||||
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_terminate(self, celery_task_id: str) -> None:
|
||||
"""This sets a signal. It does not block!"""
|
||||
# We shouldn't need very long to terminate the spawned task.
|
||||
# 10 minute TTL is good.
|
||||
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
|
||||
|
||||
@@ -134,7 +135,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
@@ -76,7 +77,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -89,7 +90,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -15,7 +15,6 @@ from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import count_punctuation
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -49,7 +48,6 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
|
||||
return model_output
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def multilingual_query_expansion(
|
||||
query: str,
|
||||
expansion_languages: list[str],
|
||||
|
||||
@@ -81,6 +81,7 @@ def load_personas_from_yaml(
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
|
||||
@@ -5,7 +5,7 @@ personas:
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Knowledge"
|
||||
name: "Search"
|
||||
description: >
|
||||
Assistant with access to documents from your Connected Sources.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import count_index_attempts_for_connector
|
||||
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_active_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
@@ -158,7 +161,19 @@ def update_cc_pair_status(
|
||||
status_update_request: CCStatusUpdateRequest,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""This method may wait up to 30 seconds if pausing the connector due to the need to
|
||||
terminate tasks in progress. Tasks are not guaranteed to terminate within the
|
||||
timeout.
|
||||
|
||||
Returns HTTPStatus.OK if everything finished.
|
||||
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
|
||||
did not finish within the timeout.
|
||||
"""
|
||||
WAIT_TIMEOUT = 15.0
|
||||
still_terminating = False
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
@@ -173,10 +188,76 @@ def update_cc_pair_status(
|
||||
)
|
||||
|
||||
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
|
||||
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
|
||||
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
try:
|
||||
redis_connector.stop.set_fence(True)
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
|
||||
)
|
||||
wait_succeeded = redis_connector.wait_for_indexing_termination(
|
||||
search_settings_list, WAIT_TIMEOUT
|
||||
)
|
||||
if wait_succeeded:
|
||||
logger.debug(
|
||||
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"Wait for indexing soft termination timed out. "
|
||||
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
|
||||
)
|
||||
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings.id
|
||||
)
|
||||
if not redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
index_payload = redis_connector_index.payload
|
||||
if not index_payload:
|
||||
continue
|
||||
|
||||
if not index_payload.celery_task_id:
|
||||
continue
|
||||
|
||||
# Revoke the task to prevent it from running
|
||||
primary_app.control.revoke(index_payload.celery_task_id)
|
||||
|
||||
# If it is running, then signaling for termination will get the
|
||||
# watchdog thread to kill the spawned task
|
||||
redis_connector_index.set_terminate(index_payload.celery_task_id)
|
||||
|
||||
logger.debug(
|
||||
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
|
||||
)
|
||||
wait_succeeded = redis_connector.wait_for_indexing_termination(
|
||||
search_settings_list, WAIT_TIMEOUT
|
||||
)
|
||||
if wait_succeeded:
|
||||
logger.debug(
|
||||
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
|
||||
)
|
||||
still_terminating = True
|
||||
break
|
||||
finally:
|
||||
redis_connector.stop.set_fence(False)
|
||||
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -185,6 +266,18 @@ def update_cc_pair_status(
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if still_terminating:
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.ACCEPTED,
|
||||
content={
|
||||
"message": "Request accepted, background task termination still in progress"
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/admin/cc-pair/{cc_pair_id}/name")
|
||||
def update_cc_pair_name(
|
||||
@@ -267,9 +360,9 @@ def prune_cc_pair(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"Pruning cc_pair: cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
|
||||
@@ -17,9 +17,10 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.connectors.google_utils.google_auth import (
|
||||
@@ -59,6 +60,7 @@ from danswer.db.connector import delete_connector
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector import get_connector_credential_ids
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from danswer.db.connector import update_connector
|
||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
@@ -74,6 +76,7 @@ from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from danswer.db.index_attempt import get_latest_index_attempts
|
||||
@@ -86,7 +89,6 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import AuthStatus
|
||||
from danswer.server.documents.models import AuthUrl
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
@@ -792,12 +794,10 @@ def connector_run_once(
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
) -> StatusResponse[int]:
|
||||
"""Used to trigger indexing on a set of cc_pairs associated with a
|
||||
single connector."""
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
connector_id = run_info.connector_id
|
||||
specified_credential_ids = run_info.credential_ids
|
||||
|
||||
@@ -843,54 +843,41 @@ def connector_run_once(
|
||||
)
|
||||
]
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
connector_credential_pairs = [
|
||||
get_connector_credential_pair(connector_id, credential_id, db_session)
|
||||
for credential_id in credential_ids
|
||||
if credential_id not in skipped_credentials
|
||||
]
|
||||
|
||||
index_attempt_ids = []
|
||||
num_triggers = 0
|
||||
for cc_pair in connector_credential_pairs:
|
||||
if cc_pair is not None:
|
||||
attempt_id = try_creating_indexing_task(
|
||||
primary_app,
|
||||
cc_pair,
|
||||
search_settings,
|
||||
run_info.from_beginning,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
indexing_mode = IndexingMode.UPDATE
|
||||
if run_info.from_beginning:
|
||||
indexing_mode = IndexingMode.REINDEX
|
||||
|
||||
mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
|
||||
num_triggers += 1
|
||||
|
||||
logger.info(
|
||||
f"connector_run_once - marking cc_pair with indexing trigger: "
|
||||
f"connector={run_info.connector_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"indexing_trigger={indexing_mode}"
|
||||
)
|
||||
if attempt_id:
|
||||
logger.info(
|
||||
f"connector_run_once - try_creating_indexing_task succeeded: "
|
||||
f"connector={run_info.connector_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"attempt={attempt_id} "
|
||||
)
|
||||
index_attempt_ids.append(attempt_id)
|
||||
else:
|
||||
logger.info(
|
||||
f"connector_run_once - try_creating_indexing_task failed: "
|
||||
f"connector={run_info.connector_id} "
|
||||
f"cc_pair={cc_pair.id}"
|
||||
)
|
||||
|
||||
if not index_attempt_ids:
|
||||
msg = "No new indexing attempts created, indexing jobs are queued or running."
|
||||
logger.info(msg)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=msg,
|
||||
)
|
||||
# run the beat task to pick up the triggers immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
msg = f"Successfully created {len(index_attempt_ids)} index attempts. {index_attempt_ids}"
|
||||
msg = f"Marked {num_triggers} index attempts with indexing triggers."
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message=msg,
|
||||
data=index_attempt_ids,
|
||||
data=num_triggers,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -199,7 +200,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
|
||||
# run the beat task to pick up this deletion from the db immediately
|
||||
primary_app.send_task(
|
||||
"check_for_connector_deletion_task",
|
||||
DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
@@ -45,6 +45,7 @@ class UserPreferences(BaseModel):
|
||||
visible_assistants: list[int] = []
|
||||
recent_assistants: list[int] | None = None
|
||||
default_model: str | None = None
|
||||
auto_scroll: bool | None = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
@@ -79,6 +80,7 @@ class UserInfo(BaseModel):
|
||||
role=user.role,
|
||||
preferences=(
|
||||
UserPreferences(
|
||||
auto_scroll=user.auto_scroll,
|
||||
chosen_assistants=user.chosen_assistants,
|
||||
default_model=user.default_model,
|
||||
hidden_assistants=user.hidden_assistants,
|
||||
@@ -128,6 +130,10 @@ class HiddenUpdateRequest(BaseModel):
|
||||
hidden: bool
|
||||
|
||||
|
||||
class AutoScrollRequest(BaseModel):
|
||||
auto_scroll: bool | None
|
||||
|
||||
|
||||
class SlackBotCreationRequest(BaseModel):
|
||||
name: str
|
||||
enabled: bool
|
||||
@@ -156,6 +162,7 @@ class SlackChannelConfigCreationRequest(BaseModel):
|
||||
channel_name: str
|
||||
respond_tag_only: bool = False
|
||||
respond_to_bots: bool = False
|
||||
show_continue_in_web_ui: bool = False
|
||||
enable_auto_filters: bool = False
|
||||
# If no team members, assume respond in the channel to everyone
|
||||
respond_member_group_list: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -80,6 +80,10 @@ def _form_channel_config(
|
||||
if follow_up_tags is not None:
|
||||
channel_config["follow_up_tags"] = follow_up_tags
|
||||
|
||||
channel_config[
|
||||
"show_continue_in_web_ui"
|
||||
] = slack_channel_config_creation_request.show_continue_in_web_ui
|
||||
|
||||
channel_config[
|
||||
"respond_to_bots"
|
||||
] = slack_channel_config_creation_request.respond_to_bots
|
||||
|
||||
@@ -34,7 +34,6 @@ from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.api_key import is_api_key_email_address
|
||||
@@ -52,6 +51,7 @@ from danswer.db.users import list_users
|
||||
from danswer.db.users import validate_user_role_update
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.manage.models import AllUsersResponse
|
||||
from danswer.server.manage.models import AutoScrollRequest
|
||||
from danswer.server.manage.models import UserByEmail
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.server.manage.models import UserPreferences
|
||||
@@ -63,6 +63,7 @@ from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.server.utils import send_user_email_invite
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.configs.app_configs import SUPER_USERS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -497,7 +498,6 @@ def verify_user_logged_in(
|
||||
return fetch_no_auth_user(store)
|
||||
|
||||
raise BasicAuthenticationError(detail="User Not Authenticated")
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
@@ -581,6 +581,30 @@ def update_user_recent_assistants(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.patch("/auto-scroll")
|
||||
def update_user_auto_scroll(
|
||||
request: AutoScrollRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
no_auth_user.preferences.auto_scroll = request.auto_scroll
|
||||
set_no_auth_user_preferences(store, no_auth_user.preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(auto_scroll=request.auto_scroll)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.patch("/user/default-model")
|
||||
def update_user_default_model(
|
||||
request: ChosenDefaultModelRequest,
|
||||
|
||||
@@ -27,9 +27,11 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
from danswer.db.chat import add_chats_to_session_from_slack_thread
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import delete_chat_session
|
||||
from danswer.db.chat import duplicate_chat_session_for_user_from_slack
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
@@ -532,6 +534,38 @@ def seed_chat(
|
||||
)
|
||||
|
||||
|
||||
class SeedChatFromSlackRequest(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
class SeedChatFromSlackResponse(BaseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
@router.post("/seed-chat-session-from-slack")
|
||||
def seed_chat_from_slack(
|
||||
chat_seed_request: SeedChatFromSlackRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeedChatFromSlackResponse:
|
||||
slack_chat_session_id = chat_seed_request.chat_session_id
|
||||
new_chat_session = duplicate_chat_session_for_user_from_slack(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
chat_session_id=slack_chat_session_id,
|
||||
)
|
||||
|
||||
add_chats_to_session_from_slack_thread(
|
||||
db_session=db_session,
|
||||
slack_chat_session_id=slack_chat_session_id,
|
||||
new_chat_session_id=new_chat_session.id,
|
||||
)
|
||||
|
||||
return SeedChatFromSlackResponse(
|
||||
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
|
||||
)
|
||||
|
||||
|
||||
"""File upload"""
|
||||
|
||||
|
||||
@@ -673,14 +707,18 @@ def upload_files_for_chat(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/file/{file_id}")
|
||||
@router.get("/file/{file_id:path}")
|
||||
def fetch_chat_file(
|
||||
file_id: str,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> Response:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
|
||||
|
||||
return StreamingResponse(file_io, media_type=media_type)
|
||||
|
||||
@@ -79,6 +79,7 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
message: str
|
||||
# Files that we should attach to this message
|
||||
file_descriptors: list[FileDescriptor]
|
||||
|
||||
# If no prompt provided, uses the largest prompt of the chat session
|
||||
# but really this should be explicitly specified, only in the simplified APIs is this inferred
|
||||
# Use prompt_id 0 to use the system default prompt which is Answer-Question
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -38,10 +37,6 @@ basic_router = APIRouter(prefix="/settings")
|
||||
def put_settings(
|
||||
settings: Settings, _: User | None = Depends(current_admin_user)
|
||||
) -> None:
|
||||
try:
|
||||
settings.check_validity()
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
|
||||
@@ -41,33 +41,10 @@ class Notification(BaseModel):
|
||||
class Settings(BaseModel):
|
||||
"""General settings"""
|
||||
|
||||
chat_page_enabled: bool = True
|
||||
search_page_enabled: bool = True
|
||||
default_page: PageType = PageType.SEARCH
|
||||
maximum_chat_retention_days: int | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
product_gating: GatingType = GatingType.NONE
|
||||
|
||||
def check_validity(self) -> None:
|
||||
chat_page_enabled = self.chat_page_enabled
|
||||
search_page_enabled = self.search_page_enabled
|
||||
default_page = self.default_page
|
||||
|
||||
if chat_page_enabled is False and search_page_enabled is False:
|
||||
raise ValueError(
|
||||
"One of `search_page_enabled` and `chat_page_enabled` must be True."
|
||||
)
|
||||
|
||||
if default_page == PageType.CHAT and chat_page_enabled is False:
|
||||
raise ValueError(
|
||||
"The default page cannot be 'chat' if the chat page is disabled."
|
||||
)
|
||||
|
||||
if default_page == PageType.SEARCH and search_page_enabled is False:
|
||||
raise ValueError(
|
||||
"The default page cannot be 'search' if the search page is disabled."
|
||||
)
|
||||
|
||||
|
||||
class UserSettings(Settings):
|
||||
notifications: list[Notification]
|
||||
|
||||
@@ -38,6 +38,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.seeding.load_docs import seed_initial_documents
|
||||
from danswer.seeding.load_yamls import load_chat_yamls
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.store import load_settings
|
||||
@@ -149,7 +150,7 @@ def setup_danswer(
|
||||
# update multipass indexing setting based on GPU availability
|
||||
update_default_multipass_indexing(db_session)
|
||||
|
||||
# seed_initial_documents(db_session, tenant_id, cohere_enabled)
|
||||
seed_initial_documents(db_session, tenant_id, cohere_enabled)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
@@ -253,13 +254,14 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_public_credential(db_session)
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls(db_session)
|
||||
|
||||
logger.notice("Loading built-in tools")
|
||||
load_builtin_tools(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
load_chat_yamls(db_session)
|
||||
|
||||
refresh_built_in_tools_cache(db_session)
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
|
||||
updated_at=datetime.now(),
|
||||
link=result.link,
|
||||
source_links={0: result.link},
|
||||
match_highlights=[],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,72 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from ee.danswer.configs.app_configs import SUPER_USERS
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.server.seeding import get_seed_config
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_public_key() -> str | None:
|
||||
if JWT_PUBLIC_KEY_URL is None:
|
||||
logger.error("JWT_PUBLIC_KEY_URL is not set")
|
||||
return None
|
||||
|
||||
response = requests.get(JWT_PUBLIC_KEY_URL)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
|
||||
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
|
||||
try:
|
||||
public_key_pem = get_public_key()
|
||||
if public_key_pem is None:
|
||||
logger.error("Failed to retrieve public key")
|
||||
return None
|
||||
|
||||
payload = jwt_decode(
|
||||
token,
|
||||
public_key_pem,
|
||||
algorithms=["RS256"],
|
||||
audience=None,
|
||||
)
|
||||
email = payload.get("email")
|
||||
if email:
|
||||
result = await async_db_session.execute(
|
||||
select(User).where(func.lower(User.email) == func.lower(email))
|
||||
)
|
||||
return result.scalars().first()
|
||||
except InvalidTokenError:
|
||||
logger.error("Invalid JWT token")
|
||||
get_public_key.cache_clear()
|
||||
except PyJWTError as e:
|
||||
logger.error(f"JWT decoding error: {str(e)}")
|
||||
get_public_key.cache_clear()
|
||||
return None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
@@ -38,6 +87,13 @@ async def optional_user_(
|
||||
)
|
||||
user = saml_account.user if saml_account else None
|
||||
|
||||
# If user is still None, check for JWT in Authorization header
|
||||
if user is None and JWT_PUBLIC_KEY_URL is not None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
user = await verify_jwt_token(token, async_db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -4,16 +4,17 @@ from typing import Any
|
||||
from danswer.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": "autogenerate_usage_report_task",
|
||||
"task": DanswerCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": "check_ttl_management_task",
|
||||
"task": DanswerCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
},
|
||||
]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
# Applicable for OIDC Auth
|
||||
@@ -19,3 +20,11 @@ STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
|
||||
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
|
||||
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential__UserGroup
|
||||
@@ -298,6 +299,11 @@ def fetch_user_groups_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, list[str]]]:
|
||||
"""
|
||||
Fetches all user groups that have access to the given documents.
|
||||
|
||||
NOTE: this doesn't include groups if the cc_pair is access type SYNC
|
||||
"""
|
||||
stmt = (
|
||||
select(Document.id, func.array_agg(UserGroup.name))
|
||||
.join(
|
||||
@@ -306,7 +312,11 @@ def fetch_user_groups_for_documents(
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
|
||||
and_(
|
||||
ConnectorCredentialPair.id
|
||||
== UserGroup__ConnectorCredentialPair.cc_pair_id,
|
||||
ConnectorCredentialPair.access_type != AccessType.SYNC,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
|
||||
@@ -16,7 +16,7 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
|
||||
_REQUEST_PAGINATION_LIMIT = 100
|
||||
_REQUEST_PAGINATION_LIMIT = 5000
|
||||
|
||||
|
||||
def _get_server_space_permissions(
|
||||
@@ -97,6 +97,7 @@ def _get_space_permissions(
|
||||
confluence_client: OnyxConfluence,
|
||||
is_cloud: bool,
|
||||
) -> dict[str, ExternalAccess]:
|
||||
logger.debug("Getting space permissions")
|
||||
# Gets all the spaces in the Confluence instance
|
||||
all_space_keys = []
|
||||
start = 0
|
||||
@@ -113,6 +114,7 @@ def _get_space_permissions(
|
||||
start += len(spaces_batch.get("results", []))
|
||||
|
||||
# Gets the permissions for each space
|
||||
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
||||
for space_key in all_space_keys:
|
||||
if is_cloud:
|
||||
@@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space(
|
||||
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
|
||||
logger.debug("Finished fetching all page restrictions for space")
|
||||
return document_restrictions
|
||||
|
||||
|
||||
@@ -254,27 +257,28 @@ def confluence_doc_sync(
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated
|
||||
"""
|
||||
logger.debug("Starting confluence doc sync")
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
if confluence_connector.confluence_client is None:
|
||||
raise ValueError("Failed to load credentials")
|
||||
confluence_client = confluence_connector.confluence_client
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
space_permissions_by_space_key = _get_space_permissions(
|
||||
confluence_client=confluence_client,
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
is_cloud=is_cloud,
|
||||
)
|
||||
|
||||
slim_docs = []
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
confluence_client=confluence_client,
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,10 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
user = user_result["user"]
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
continue
|
||||
email = user.get("email")
|
||||
if not email:
|
||||
# This field is only present in Confluence Server
|
||||
|
||||
@@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 60 seconds
|
||||
DocumentSource.GOOGLE_DRIVE: 60,
|
||||
DocumentSource.CONFLUENCE: 60,
|
||||
# Polling is not supported so we fetch all group permissions every 5 minutes
|
||||
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.main import get_application as get_application_base
|
||||
from danswer.main import include_router_with_global_prefix_prepended
|
||||
from danswer.server.api_key.api import router as api_key_router
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
@@ -116,8 +115,6 @@ def get_application() -> FastAPI:
|
||||
# Analytics endpoints
|
||||
include_router_with_global_prefix_prepended(application, analytics_router)
|
||||
include_router_with_global_prefix_prepended(application, query_history_router)
|
||||
# Api key management
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
# EE only backend APIs
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
|
||||
@@ -113,10 +113,6 @@ async def refresh_access_token(
|
||||
def put_settings(
|
||||
settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
|
||||
) -> None:
|
||||
try:
|
||||
settings.check_validity()
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
|
||||
@@ -157,7 +157,6 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
|
||||
def _seed_settings(settings: Settings) -> None:
|
||||
logger.notice("Seeding Settings")
|
||||
try:
|
||||
settings.check_validity()
|
||||
store_base_settings(settings)
|
||||
logger.notice("Successfully seeded Settings")
|
||||
except ValueError as e:
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import embedding
|
||||
from litellm.exceptions import RateLimitError
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
@@ -205,28 +206,22 @@ class CloudEmbedding:
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
) -> list[Embedding]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return self._embed_litellm_proxy(texts, model_name)
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error embedding text with {self.provider}: {str(e)}",
|
||||
)
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -430,6 +425,11 @@ async def process_embed_request(
|
||||
prefix=prefix,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
|
||||
@@ -29,7 +29,7 @@ trafilatura==1.12.2
|
||||
langchain==0.1.17
|
||||
langchain-core==0.1.50
|
||||
langchain-text-splitters==0.0.1
|
||||
litellm==1.50.2
|
||||
litellm==1.53.1
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
llama-index==0.9.45
|
||||
@@ -38,7 +38,7 @@ msal==1.28.0
|
||||
nltk==3.8.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.52.2
|
||||
openai==1.55.3
|
||||
openpyxl==3.1.2
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user