Compare commits

..

1 Commits

Author SHA1 Message Date
pablodanswer
856c2debd9 add danswer api key header 2024-10-29 13:29:03 -07:00
381 changed files with 6069 additions and 13822 deletions

View File

@@ -1,76 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@v2
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# sarif_file: trivy-results.sarif

View File

@@ -18,9 +18,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
jobs:
connectors-check:

View File

@@ -15,7 +15,7 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
model-check:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

View File

@@ -1,5 +1,4 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

View File

@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,6 +39,15 @@ RUN apt-get update && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -77,6 +87,7 @@ RUN apt-get update && \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \

View File

@@ -1,50 +0,0 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
)
)
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)

View File

@@ -1,70 +0,0 @@
"""nullable search settings for historic index attempts
Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be nullable
op.alter_column(
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
)
# Add back the foreign key with ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Warning: This will delete all index attempts that don't have search settings
op.execute(
"""
DELETE FROM index_attempt
WHERE search_settings_id IS NULL
"""
)
# Drop foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be not nullable
op.alter_column(
"index_attempt",
"search_settings_id",
existing_type=sa.INTEGER(),
nullable=False,
)
# Add back the foreign key without ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)

View File

@@ -93,9 +93,9 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -510,23 +510,19 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
async def write_token(self, user: User) -> str:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> TenantAwareJWTStrategy:
def get_jwt_strategy() -> JWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,

View File

@@ -14,19 +14,18 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
@@ -109,27 +108,29 @@ def on_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(tenant_id, int(document_set_id))
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(tenant_id, int(usergroup_id))
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDelete.PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@@ -172,30 +173,44 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Exit early if multi-tenant since primary worker check not needed
if MULTI_TENANT:
return
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
@@ -203,7 +218,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
logger.info("All tenant primary workers are ready. Continuing...")
return
@@ -215,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
def on_setup_logging(

View File

@@ -3,144 +3,26 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger(__name__)
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = self.schedule.items()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
# celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
@@ -153,4 +35,68 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.conf.beat_scheduler = DynamicTenantScheduler
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -85,6 +85,5 @@ celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
]
)

View File

@@ -13,15 +13,21 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -73,45 +79,91 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as the primary celery worker.")
if MULTI_TENANT:
return
sender.primary_worker_locks = {}
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
@worker_ready.connect
@@ -164,36 +216,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
lock = worker.primary_worker_lock
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
r = get_redis_client(tenant_id=None)
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
worker.primary_worker_lock = lock
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception:
task_logger.exception("Periodic task failed.")

View File

@@ -1,96 +0,0 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=1)
self._last_reload = self.app.now() - self._reload_interval
def setup_schedule(self) -> None:
super().setup_schedule()
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reloading schedule to check for new tenants...")
self._update_tenant_tasks()
self._last_reload = now
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Checking for tenant task updates...")
try:
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = getattr(self, "_store", {"entries": {}}).get(
"entries", {}
)
existing_tenants = set()
for task_name in current_schedule.keys():
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Found new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Updating schedule",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
if not hasattr(self, "_store"):
self._store: dict[str, dict] = {"entries": {}}
self.update_from_dict(new_beat_schedule)
logger.info(f"New schedule: {new_beat_schedule}")
logger.info("Tenant tasks updated successfully")
else:
logger.debug("No schedule updates needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
current_tasks = set(current_schedule.keys())
new_tasks = set(new_schedule.keys())
return current_tasks != new_tasks

View File

@@ -1,10 +1,568 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def is_indexing(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:

View File

@@ -4,6 +4,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
@@ -17,7 +18,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -40,14 +41,14 @@ def _get_deletion_status(
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
)

View File

@@ -1,48 +0,0 @@
from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -10,6 +10,13 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
@@ -18,8 +25,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -57,7 +62,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcs = RedisConnectorStop(cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
@@ -66,10 +71,10 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
r.delete(rcs.fence_key)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -101,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
# don't generate sync tasks if tasks are still pending
if redis_connector.delete.fenced:
if r.exists(rcd.fence_key):
return None
# we need to load the state of the object inside the fence
@@ -118,49 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletionFenceData(
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
redis_connector.delete.set_fence(fence_payload)
r.set(rcd.fence_key, fence_value.model_dump_json())
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
if redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.delete.generate_tasks(
app, db_session, lock_beat
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
@@ -175,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
return tasks_generated

View File

@@ -2,6 +2,8 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
@@ -12,6 +14,12 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
@@ -42,8 +50,6 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -99,22 +105,19 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
return None
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if current_search_settings.provider_type is None and not MULTI_TENANT:
if old_search_settings:
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
# only warm up if search settings were changed
warm_up_bi_encoder(
embedding_model=embedding_model,
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -123,7 +126,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
@@ -136,10 +138,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
rci = RedisConnectorIndexing(
cc_pair_id, search_settings_instance.id
)
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -302,15 +304,15 @@ def try_creating_indexing_task(
return None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -318,17 +320,19 @@ def try_creating_indexing_task(
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_fence(payload)
r.set(rci.fence_key, fence_value.model_dump_json())
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
@@ -341,8 +345,6 @@ def try_creating_indexing_task(
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -359,12 +361,11 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
redis_connector_index.set_fence(payload)
r.delete(rci.fence_key)
task_logger.exception(
f"Unexpected exception: "
f"tenant={tenant_id} "
@@ -387,12 +388,7 @@ def connector_indexing_proxy_task(
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
client = SimpleJobClient()
job = client.submit(
@@ -406,56 +402,29 @@ def connector_indexing_proxy_task(
)
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={job.exception()}"
)
if job.status == "error":
logger.error(job.exception())
job.release()
break
job.release()
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@@ -477,78 +446,78 @@ def connector_indexing_task(
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
logger.info(
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
attempt = None
n_final_progress: int | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
f"fence={rcd.fence_key}"
)
if redis_connector.stop.fenced:
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
f"fence={rcs.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
task_logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
)
sleep(1)
continue
logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
task_logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
)
break
lock = r.lock(
redis_connector_index.generator_lock_key,
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -576,19 +545,11 @@ def connector_indexing_task(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# define a callback class
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
rcs.fence_key, rci.generator_progress_key, lock, r
)
run_indexing_entrypoint(
@@ -600,29 +561,27 @@ def connector_indexing_task(
)
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return n_final_progress

View File

@@ -11,6 +11,9 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
@@ -30,9 +33,7 @@ from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import pruning_ctx
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -145,11 +146,8 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
count = redis_connector.prune.get_active_task_count()
if count > 0:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
LOCK_TIMEOUT = 30
@@ -166,10 +164,15 @@ def try_creating_prune_generator_task(
return None
try:
if redis_connector.prune.fenced: # skip pruning if already pruning
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
# skip pruning if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -177,10 +180,10 @@ def try_creating_prune_generator_task(
return None
# add a long running generator task to the queue
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
@@ -196,7 +199,7 @@ def try_creating_prune_generator_task(
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -226,17 +229,12 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
pruning_ctx_dict = pruning_ctx.get()
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
@@ -269,11 +267,10 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
@@ -300,29 +297,31 @@ def connector_pruning_generator_task(
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
)
redis_connector.prune.generator_complete = tasks_generated
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():

View File

@@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@@ -0,0 +1,10 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None

View File

@@ -19,6 +19,18 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -55,14 +67,7 @@ from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@@ -187,7 +192,7 @@ def try_generate_stale_document_sync_tasks(
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
@@ -223,10 +228,10 @@ def try_generate_document_set_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(tenant_id, document_set_id)
rds = RedisDocumentSet(document_set_id)
# don't generate document set sync tasks if tasks are still pending
if rds.fenced:
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
@@ -264,7 +269,7 @@ def try_generate_document_set_sync_tasks(
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
r.set(rds.fence_key, tasks_generated)
return tasks_generated
@@ -278,9 +283,10 @@ def try_generate_user_group_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(tenant_id, usergroup_id)
if rug.fenced:
# don't generate sync tasks if tasks are still pending
rug = RedisUserGroup(usergroup_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
@@ -320,7 +326,7 @@ def try_generate_user_group_sync_tasks(
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
r.set(rug.fence_key, tasks_generated)
return tasks_generated
@@ -346,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -356,12 +362,16 @@ def monitor_document_set_taskset(
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(tenant_id, document_set_id)
if not rds.fenced:
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
initial_count = rds.payload
if initial_count is None:
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
@@ -389,38 +399,48 @@ def monitor_document_set_taskset(
f"Successfully synced document set: document_set={document_set_id}"
)
rds.reset()
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
return
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if remaining > 0:
if count > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
@@ -504,15 +524,15 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
@@ -521,37 +541,46 @@ def monitor_ccpair_pruning_taskset(
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
initial = redis_connector.prune.generator_complete
if initial is None:
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
remaining = redis_connector.prune.get_remaining()
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if remaining > 0:
if count > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
@@ -566,37 +595,53 @@ def monitor_ccpair_indexing_taskset(
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
status_int = redis_connector_index.get_completion()
if status_int is None:
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
@@ -607,7 +652,7 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
@@ -615,10 +660,22 @@ def monitor_ccpair_indexing_taskset(
failure_reason="Connector indexing aborted or exceptioned.",
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus(status_int)
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
@@ -627,7 +684,11 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@@ -639,7 +700,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early to prevent overlap
Returns True if the task actually did work, False
"""
r = get_redis_client(tenant_id=tenant_id)
@@ -690,12 +751,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(
RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
)
):
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
@@ -703,15 +763,15 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
@@ -722,19 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)

View File

@@ -118,13 +118,7 @@ def _run_indexing(
"""
start_time = time.time()
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs

View File

@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):

View File

@@ -18,7 +18,6 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
@@ -78,49 +77,31 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
@@ -279,7 +260,6 @@ ChatPacket = (
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -552,13 +532,6 @@ def stream_chat_message_objects(
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
@@ -577,16 +550,13 @@ def stream_chat_message_objects(
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -656,11 +626,7 @@ def stream_chat_message_objects(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
InternetSearchTool(api_key=bing_api_key)
]
continue
@@ -701,7 +667,13 @@ def stream_chat_message_objects(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=answer_style_config,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(
llm
@@ -805,8 +777,7 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
pass
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
@@ -836,7 +807,6 @@ def stream_chat_message_objects(
# Post-LLM answer processing
try:
logger.debug("Post-LLM answer processing")
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
@@ -864,15 +834,17 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else None
else []
),
)

View File

@@ -9,19 +9,19 @@ prompts:
system: >
You are a question answering system that is constantly learning and improving.
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
You always clearly communicate ANY UNCERTAINTY in your answer.
# Task Prompt (as shown in UI)
task: >
Answer my query based on the documents provided.
The documents may not all be relevant, ignore any documents that are not directly relevant
to the most recent user query.
I have not read or seen any of the documents and do not want to read them.
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
@@ -30,21 +30,21 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images from user descriptions!"
description: "Generates images based on user prompts!"
system: >
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
For appropriate requests, you will generate an image that matches the user's requirements.
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
You aim to be helpful while maintaining appropriate content standards.
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
task: >
Based on the user's description, create a high-quality image that accurately reflects their request.
Pay close attention to the specified details, styles, and desired elements.
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
@@ -64,13 +64,14 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Summarize"
description: "Summarize relevant information from retrieved context!"
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -83,6 +84,7 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Paraphrase"
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
@@ -90,10 +92,10 @@ prompts:
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!
If there are no documents provided,
simply tell the user that there are no documents to reference.
You NEVER generate new text or phrases outside of the citation.
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
task: >

View File

@@ -251,6 +251,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# for some connectors
ENABLE_EXPENSIVE_EXPERT_CALLS = False
GOOGLE_DRIVE_INCLUDE_SHARED = False
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
# TODO these should be available for frontend configuration, via advanced options expandable
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
@@ -478,7 +481,3 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

View File

@@ -125,7 +125,6 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

View File

@@ -17,7 +17,6 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
@@ -250,11 +249,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

View File

@@ -23,16 +23,7 @@ def datetime_to_utc(dt: datetime) -> datetime:
def time_str_to_utc(datetime_str: str) -> datetime:
try:
dt = parse(datetime_str)
except ValueError:
# Handle malformed timezone by attempting to fix common format issues
if "0000" in datetime_str:
# Convert "0000" to "+0000" for proper timezone parsing
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
dt = parse(fixed_dt_str)
else:
raise
dt = parse(datetime_str)
return datetime_to_utc(dt)

View File

@@ -16,7 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
from danswer.connectors.gmail.connector import GmailConnector
@@ -100,7 +99,6 @@ def identify_connector_class(
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -1,239 +0,0 @@
import json
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FRESHDESK_ID_PREFIX = "FRESHDESK_"
_TICKET_FIELDS_TO_INCLUDE = {
"fr_escalated",
"spam",
"priority",
"source",
"status",
"type",
"is_escalated",
"tags",
"nr_due_by",
"nr_escalated",
"cc_emails",
"fwd_emails",
"reply_cc_emails",
"ticket_cc_emails",
"support_email",
"to_emails",
}
_SOURCE_NUMBER_TYPE_MAP: dict[int, str] = {
1: "Email",
2: "Portal",
3: "Phone",
7: "Chat",
9: "Feedback Widget",
10: "Outbound Email",
}
_PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = {
1: "low",
2: "medium",
3: "high",
4: "urgent",
}
_STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
2: "open",
3: "pending",
4: "resolved",
5: "closed",
}
def _create_metadata_from_ticket(ticket: dict) -> dict:
metadata: dict[str, str | list[str]] = {}
# Combine all emails into a list so there are no repeated emails
email_data: set[str] = set()
for key, value in ticket.items():
# Skip fields that aren't useful for embedding
if key not in _TICKET_FIELDS_TO_INCLUDE:
continue
# Skip empty fields
if not value or value == "[]":
continue
# Convert strings or lists to strings
stringified_value: str | list[str]
if isinstance(value, list):
stringified_value = [str(item) for item in value]
else:
stringified_value = str(value)
if "email" in key:
if isinstance(stringified_value, list):
email_data.update(stringified_value)
else:
email_data.add(stringified_value)
else:
metadata[key] = stringified_value
if email_data:
metadata["emails"] = list(email_data)
# Convert source numbers to human-parsable string
if source_number := ticket.get("source"):
metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get(
source_number, "Unknown Source Type"
)
# Convert priority numbers to human-parsable string
if priority_number := ticket.get("priority"):
metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get(
priority_number, "Unknown Priority"
)
# Convert status to human-parsable string
if status_number := ticket.get("status"):
metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get(
status_number, "Unknown Status"
)
due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00"))
metadata["overdue"] = str(datetime.now(timezone.utc) > due_by)
return metadata
def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
# Use the ticket description as the text
text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}"
metadata = _create_metadata_from_ticket(ticket)
# This is also used in the ID because it is more unique than the just the ticket ID
link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}"
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
Section(
link=link,
text=text,
)
],
source=DocumentSource.FRESHDESK,
semantic_identifier=ticket["subject"],
metadata=metadata,
doc_updated_at=datetime.fromisoformat(
ticket["updated_at"].replace("Z", "+00:00")
),
)
class FreshdeskConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str | int]) -> None:
api_key = credentials.get("freshdesk_api_key")
domain = credentials.get("freshdesk_domain")
password = credentials.get("freshdesk_password")
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
raise ConnectorMissingCredentialError(
"All Freshdesk credentials must be strings"
)
self.api_key = str(api_key)
self.domain = str(domain)
self.password = str(password)
def _fetch_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> Iterator[List[dict]]:
"""
'end' is not currently used, so we may double fetch tickets created after the indexing
starts but before the actual call is made.
To use 'end' would require us to use the search endpoint but it has limitations,
namely having to fetch all IDs and then individually fetch each ticket because there is no
'include' field available for this endpoint:
https://developers.freshdesk.com/api/#filter_tickets
"""
if self.api_key is None or self.domain is None or self.password is None:
raise ConnectorMissingCredentialError("freshdesk")
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
params: dict[str, int | str] = {
"include": "description",
"per_page": 50,
"page": 1,
}
if start:
params["updated_since"] = start.isoformat()
while True:
response = requests.get(
base_url, auth=(self.api_key, self.password), params=params
)
response.raise_for_status()
if response.status_code == 204:
break
tickets = json.loads(response.content)
logger.info(
f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})"
)
yield tickets
if len(tickets) < int(params["per_page"]):
break
params["page"] = int(params["page"]) + 1
def _process_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for ticket_batch in self._fetch_tickets(start, end):
for ticket in ticket_batch:
doc_batch.append(_create_doc_from_ticket(ticket, self.domain))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_tickets()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_tickets(start_datetime, end_datetime)

View File

@@ -1,8 +1,4 @@
import re
import time
from base64 import urlsafe_b64decode
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
@@ -10,7 +6,6 @@ from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@@ -39,64 +34,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
attempt = 0
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
try:
return request.execute()
except HttpError as error:
attempt += 1
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Extract 'Retry after' timestamp from error message
match = re.search(
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
str(error),
)
if match:
retry_after_timestamp = match.group(1)
retry_after_dt = datetime.strptime(
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
).replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
sleep_time = max(
int((retry_after_dt - current_time).total_seconds()),
0,
)
else:
logger.error(
f"No Retry-After header or timestamp found in error message: {error}"
)
sleep_time = 60
sleep_time += 3 # Add a buffer to be safe
logger.info(
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
)
time.sleep(sleep_time)
else:
raise
# If we've exhausted all attempts
raise Exception(f"Failed to execute request after {max_attempts} attempts")
class GmailConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
@@ -219,7 +156,7 @@ class GmailConnector(LoadConnector, PollConnector):
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
service = discovery.build("gmail", "v1", credentials=self.creds)
while page_token is not None:
result = _execute_with_retry(
result = (
service.users()
.messages()
.list(
@@ -228,17 +165,18 @@ class GmailConnector(LoadConnector, PollConnector):
q=query,
maxResults=self.batch_size,
)
.execute()
)
page_token = result.get("nextPageToken")
messages = result.get("messages", [])
doc_batch = []
for message in messages:
message_id = message["id"]
msg = _execute_with_retry(
msg = (
service.users()
.messages()
.get(userId="me", id=message_id, format="full")
.execute()
)
doc = self._email_to_document(msg)
doc_batch.append(doc)

View File

@@ -1,305 +1,556 @@
import io
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.connectors.google_drive.connector_auth import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_drive.constants import SCOPE_DOC_URL
from danswer.connectors.google_drive.constants import SLIM_BATCH_SIZE
from danswer.connectors.google_drive.constants import USER_FIELDS
from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import SlimDocument
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
return []
return [s.strip() for s in string.split(",") if s.strip()]
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
return [url.split("/")[-1] for url in urls]
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
GoogleDriveFileType = dict[str, Any]
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _run_drive_file_query(
service: discovery.Resource,
query: str,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = add_retries(
lambda: (
service.files()
.list(
corpora="allDrives"
if include_shared
else "user", # needed to search through shared drives
pageSize=batch_size,
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
fields=(
"nextPageToken, files(mimeType, id, name, permissions, "
"modifiedTime, webViewLink, shortcutDetails)"
),
pageToken=next_page_token,
q=query,
)
.execute()
)
)()
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
if follow_shortcuts and "shortcutDetails" in file:
try:
file_shortcut_points_to = add_retries(
lambda: (
service.files()
.get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
)
.execute()
)
)()
yield file_shortcut_points_to
except HttpError:
logger.error(
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
)
if continue_on_failure:
continue
raise
else:
yield file
def _get_folder_id(
service: discovery.Resource,
parent_id: str,
folder_name: str,
include_shared: bool,
follow_shortcuts: bool,
) -> str | None:
"""
Get the ID of a folder given its name and the ID of its parent folder.
"""
query = f"'{parent_id}' in parents and name='{folder_name}' and "
if follow_shortcuts:
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
else:
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
# TODO: support specifying folder path in shared drive rather than just `My Drive`
results = add_retries(
lambda: (
service.files()
.list(
q=query,
spaces="drive",
fields="nextPageToken, files(id, name, shortcutDetails)",
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
)
.execute()
)
)()
items = results.get("files", [])
folder_id = None
if items:
if follow_shortcuts and "shortcutDetails" in items[0]:
folder_id = items[0]["shortcutDetails"]["targetId"]
else:
folder_id = items[0]["id"]
return folder_id
def _get_folders(
service: discovery.Resource,
continue_on_failure: bool,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if follow_shortcuts:
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
for file in _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
):
# Need to check this since file may have been a target of a shortcut
# and not necessarily a folder
if file["mimeType"] == DRIVE_FOLDER_TYPE:
yield file
else:
pass
def _get_files(
service: discovery.Resource,
continue_on_failure: bool,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
files = _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
return files
def get_all_files_batched(
service: discovery.Resource,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID.
# Only applies if folder_id is specified.
traverse_subfolders: bool = True,
folder_ids_traversed: list[str] | None = None,
) -> Iterator[list[GoogleDriveFileType]]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
found_files = _get_files(
service=service,
continue_on_failure=continue_on_failure,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.debug(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
if traverse_subfolders and folder_id is not None:
folder_ids_traversed = folder_ids_traversed or []
subfolders = _get_folders(
service=service,
folder_id=folder_id,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
for subfolder in subfolders:
if subfolder["id"] not in folder_ids_traversed:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
folder_ids_traversed.append(subfolder["id"])
yield from get_all_files_batched(
service=service,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=subfolder["id"],
traverse_subfolders=traverse_subfolders,
folder_ids_traversed=folder_ids_traversed,
)
else:
logger.debug(
"Skipping subfolder since already traversed: " + subfolder["name"]
)
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector):
def __init__(
self,
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# OLD PARAMETERS
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
# if specified, will only index files in these folders
folder_paths: list[str] | None = None,
include_shared: bool | None = None,
follow_shortcuts: bool | None = None,
only_org_public: bool | None = None,
continue_on_failure: bool | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
# Check for old input parameters
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
if (
not include_shared_drives
and not include_my_drives
and not shared_folder_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
)
self.folder_paths = folder_paths or []
self.batch_size = batch_size
self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list)
self.include_my_drives = include_my_drives
self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails)
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)
self.primary_admin_email: str | None = None
self.google_domain: str | None = None
self.include_shared = include_shared
self.follow_shortcuts = follow_shortcuts
self.only_org_public = only_org_public
self.continue_on_failure = continue_on_failure
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._TRAVERSED_PARENT_IDS: set[str] = set()
@staticmethod
def _process_folder_paths(
service: discovery.Resource,
folder_paths: list[str],
include_shared: bool,
follow_shortcuts: bool,
) -> list[str]:
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
folder_ids: list[str] = []
for path in folder_paths:
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = _get_folder_id(
service=service,
parent_id=parent_id,
folder_name=folder_name,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
)
if found_parent_id is None:
raise ValueError(
(
f"Folder '{folder_name}' in path '{path}' "
"not found in Google Drive"
)
)
parent_id = found_parent_id
folder_ids.append(parent_id)
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._TRAVERSED_PARENT_IDS.add(folder_id)
return folder_ids
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self.google_domain = primary_admin_email.split("@")[1]
self.primary_admin_email = primary_admin_email
self.creds, new_creds_dict = get_google_drive_creds(credentials)
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
return new_creds_dict
def get_google_resource(
self,
service_name: str = "drive",
service_version: str = "v3",
user_email: str | None = None,
) -> Resource:
if isinstance(self.creds, ServiceAccountCredentials):
creds = self.creds.with_subject(user_email or self.primary_admin_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(self.creds, OAuthCredentials):
service = build(service_name, service_version, credentials=self.creds)
else:
raise PermissionError("No credentials found")
return service
def _get_all_user_emails(self) -> list[str]:
admin_service = self.get_google_resource("admin", "directory_v1")
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
primary_drive_service = self.get_google_resource()
if self.include_shared_drives:
shared_drive_urls = self.shared_drive_ids
if not shared_drive_urls:
# if no parent ids are specified, get all shared drives using the admin account
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=True,
fields="drives(id)",
):
shared_drive_urls.append(drive["id"])
# For each shared drive, retrieve all files
for shared_drive_id in shared_drive_urls:
for file in get_files_in_shared_drive(
service=primary_drive_service,
drive_id=shared_drive_id,
is_slim=is_slim,
cache_folders=bool(self.shared_folder_ids),
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
):
yield file
if self.shared_folder_ids:
# Crawl all the shared parent ids for files
for folder_id in self.shared_folder_ids:
yield from crawl_folders_for_files(
service=primary_drive_service,
parent_id=folder_id,
personal_drive=False,
traversed_parent_ids=self._TRAVERSED_PARENT_IDS,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
all_user_emails = []
# get all personal docs from each users' personal drive
if self.include_my_drives:
if isinstance(self.creds, ServiceAccountCredentials):
all_user_emails = self.my_drive_emails or []
# If using service account and no emails specified, fetch all users
if not all_user_emails:
all_user_emails = self._get_all_user_emails()
elif self.primary_admin_email:
# If using OAuth, only fetch the primary admin email
all_user_emails = [self.primary_admin_email]
for email in all_user_emails:
logger.info(f"Fetching personal files for user: {email}")
user_drive_service = self.get_google_resource(user_email=email)
yield from get_files_in_my_drive(
service=user_drive_service,
email=email,
is_slim=is_slim,
start=start,
end=end,
)
def _extract_docs_from_google_drive(
def _fetch_docs_from_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch = []
for file in self._fetch_drive_items(
is_slim=False,
start=start,
end=end,
):
user_email = file.get("owners", [{}])[0].get("emailAddress")
service = self.get_google_resource(user_email=user_email)
if doc := convert_drive_item_to_document(
file=file,
service=service,
):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if self.creds is None:
raise PermissionError("Not logged into Google Drive")
yield doc_batch
service = discovery.build("drive", "v3", credentials=self.creds)
folder_ids: Sequence[str | None] = self._process_folder_paths(
service, self.folder_paths, self.include_shared, self.follow_shortcuts
)
if not folder_ids:
folder_ids = [None]
file_batches = chain(
*[
get_all_files_batched(
service=service,
continue_on_failure=self.continue_on_failure,
include_shared=self.include_shared,
follow_shortcuts=self.follow_shortcuts,
batch_size=self.batch_size,
time_range_start=start,
time_range_end=end,
folder_id=folder_id,
traverse_subfolders=True,
)
for folder_id in folder_ids
]
)
for files_batch in file_batches:
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
if self.only_org_public:
if "permissions" not in file:
continue
if not any(
permission["type"] == "domain"
for permission in file["permissions"]
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[
Section(link=file["webViewLink"], text=text_contents)
],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
"Ran into exception when pulling a file from Google Drive"
)
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_docs_from_drive()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
# need to subtract 10 minutes from start time to account for modifiedTime
# propogation if a document is modified, it takes some time for the API to
# reflect these changes if we do not have an offset, then we may "miss" the
# update when polling
yield from self._fetch_docs_from_drive(start, end)
def _extract_slim_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
is_slim=True,
start=start,
end=end,
):
slim_batch.append(
SlimDocument(
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
},
)
)
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
yield slim_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
if __name__ == "__main__":
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break

View File

@@ -8,16 +8,24 @@ from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from googleapiclient.discovery import build # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
@@ -28,14 +36,15 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
GOOGLE_DRIVE_SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
]
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_primary_admin"
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
@@ -43,7 +52,7 @@ def _build_frontend_google_drive_redirect() -> str:
def get_google_drive_creds_for_authorized_user(
token_json_str: str, scopes: list[str]
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
@@ -63,15 +72,21 @@ def get_google_drive_creds_for_authorized_user(
return None
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = GOOGLE_DRIVE_SCOPES
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
@@ -85,27 +100,26 @@ def get_google_drive_creds(
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
if not service_creds.valid or not service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(
"Unable to access Google Drive - service account credentials are invalid."
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
@@ -132,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_DRIVE_SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
@@ -155,34 +169,13 @@ def update_credential_access_tokens(
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=GOOGLE_DRIVE_SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
# Get user email from Google API so we know who
# the primary admin is for this connector
try:
admin_service = build("drive", "v3", credentials=creds)
user_info = (
admin_service.about()
.get(
fields="user(emailAddress)",
)
.execute()
)
email = user_info.get("user", {}).get("emailAddress")
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
}
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
@@ -191,15 +184,15 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
credential_dict = {
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(),
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
credential_json=credential_dict,

View File

@@ -1,36 +1,7 @@
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Documentation and error messages
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Danswer without updating the Google Drive scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)
# Batch sizes
SLIM_BATCH_SIZE = 500
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]

View File

@@ -1,115 +0,0 @@
import io
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _extract_text(file: dict[str, str], service: Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
def convert_drive_item_to_document(
file: GoogleDriveFileType, service: Resource
) -> Document | None:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
try:
text_contents = _extract_text(file, service) or ""
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
return Document(
id=file["webViewLink"],
sections=[Section(link=file["webViewLink"], text=text_contents)],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
return None

View File

@@ -1,192 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from googleapiclient.discovery import Resource # type: ignore
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import FILE_FIELDS
from danswer.connectors.google_drive.constants import FOLDER_FIELDS
from danswer.connectors.google_drive.constants import SLIM_FILE_FIELDS
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
return time_range_filter
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
personal_drive: bool = False,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user" if personal_drive else "allDrives",
supportsAllDrives=not personal_drive,
includeItemsFromAllDrives=not personal_drive,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_files_in_parent(
service: Resource,
parent_id: str,
personal_drive: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user" if personal_drive else "allDrives",
supportsAllDrives=not personal_drive,
includeItemsFromAllDrives=not personal_drive,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
personal_drive: bool,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
print(f"Skipping subfolder since already traversed: {parent_id}")
return
update_traversed_ids_func(parent_id)
yield from _get_files_in_parent(
service=service,
personal_drive=personal_drive,
start=start,
end=end,
parent_id=parent_id,
)
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
personal_drive=personal_drive,
):
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
personal_drive=personal_drive,
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
cache_folders: bool = True,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
if cache_folders:
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
):
update_traversed_ids_func(file["id"])
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def get_files_in_my_drive(
service: Resource,
email: str,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]

View File

@@ -1,35 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.utils.retry_wrapper import retry_builder
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()
next_page_token = results.get("nextPageToken")
for item in results.get(list_key, []):
yield item

View File

@@ -1,18 +0,0 @@
from enum import Enum
from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
GoogleDriveFileType = dict[str, Any]

View File

@@ -56,11 +56,7 @@ class PollConnector(BaseConnector):
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import builtins
import functools
import itertools
import tempfile
from typing import Any
from unittest import mock
from urllib.parse import urlparse
@@ -19,8 +18,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
@mock.patch.object(
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import datetime
import itertools
import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
@@ -26,8 +25,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
def pywikibot_timestamp_to_utc_datetime(
timestamp: pywikibot.time.Timestamp,
@@ -124,6 +121,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [

View File

@@ -251,11 +251,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []

View File

@@ -391,11 +391,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@@ -1,7 +1,10 @@
from collections.abc import Iterator
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects import Ticket # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
@@ -17,244 +20,43 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.retry_wrapper import retry_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
class ZendeskCredentialsNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__(
"Zendesk Credentials are not set up, was load_credentials called?"
)
class ZendeskClient:
def __init__(self, subdomain: str, email: str, token: str):
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
self.auth = (f"{email}/token", token)
@retry_builder()
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
response.raise_for_status()
return response.json()
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
try:
while True:
data = client.make_request("guide/content_tags", params)
for tag in data.get("records", []):
content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
return content_tags
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = (
{"start_time": start_time, "page[size]": page_size}
if start_time
else {"page[size]": page_size}
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
while True:
data = client.make_request("help_center/articles", params)
for article in data["articles"]:
yield article
if not data.get("meta", {}).get("has_more"):
break
params["page[after]"] = data["meta"]["after_cursor"]
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time} if start_time else {"start_time": 0}
while True:
data = client.make_request("incremental/tickets.json", params)
for ticket in data["tickets"]:
yield ticket
if not data.get("end_of_stream", False):
params["start_time"] = data["end_time"]
else:
break
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
def _article_to_document(
article: dict[str, Any],
content_tags: dict[str, str],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
author_id = article.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
updated_at = article.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
# Build metadata
# build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.get("label_names", []) if label],
"labels": [str(label) for label in article.label_names if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.get("content_tag_ids", [])
for tag_id in article.content_tag_ids
if tag_id in content_tags
],
}
# Remove empty values
# remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return new_author_mapping, Document(
id=f"article:{article['id']}",
return Document(
id=f"article:{article.id}",
sections=[
Section(
link=article.get("html_url"),
text=parse_html_page_basic(article["body"]),
)
Section(link=article.html_url, text=parse_html_page_basic(article.body))
],
source=DocumentSource.ZENDESK,
semantic_identifier=article["title"],
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author] if author else None,
primary_owners=[author],
metadata=metadata,
)
def _get_comment_text(
comment: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
author_id = comment.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
return new_author_mapping, comment_text
def _ticket_to_document(
ticket: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
default_subdomain: str,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
submitter_id = ticket.get("submitter")
if not submitter_id:
submitter = None
else:
submitter = (
author_map.get(submitter_id)
if submitter_id in author_map
else _fetch_author(client, submitter_id)
)
new_author_mapping = (
{submitter_id: submitter} if submitter_id and submitter else None
)
updated_at = ticket.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
metadata: dict[str, str | list[str]] = {}
if status := ticket.get("status"):
metadata["status"] = status
if priority := ticket.get("priority"):
metadata["priority"] = priority
if tags := ticket.get("tags"):
metadata["tags"] = tags
if ticket_type := ticket.get("type"):
metadata["ticket_type"] = ticket_type
# Fetch comments for the ticket
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
comments = comments_data.get("comments", [])
comment_texts = []
for comment in comments:
new_author_mapping, comment_text = _get_comment_text(
comment, author_map, client
)
if new_author_mapping:
author_map.update(new_author_mapping)
comment_texts.append(comment_text)
comments_text = "\n\n".join(comment_texts)
subject = ticket.get("subject")
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
ticket_url = ticket.get("url")
subdomain = (
ticket_url.split("//")[1].split(".zendesk.com")[0]
if ticket_url
else default_subdomain
)
ticket_display_url = (
f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}"
)
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[Section(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=[submitter] if submitter else None,
metadata=metadata,
)
class ZendeskClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Zendesk Client is not set up, was load_credentials called?")
class ZendeskConnector(LoadConnector, PollConnector):
@@ -264,10 +66,44 @@ class ZendeskConnector(LoadConnector, PollConnector):
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.content_type = content_type
self.subdomain = ""
# Fetch all tags ahead of time
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
self.content_type = content_type
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -276,23 +112,87 @@ class ZendeskConnector(LoadConnector, PollConnector):
.replace("https://", "")
.split(".zendesk.com")[0]
)
self.subdomain = subdomain
self.client = ZendeskClient(
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
self.zendesk_client = Zenpy(
subdomain=subdomain,
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def _ticket_to_document(self, ticket: Ticket) -> Document:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
owner = None
if ticket.requester and ticket.requester.name and ticket.requester.email:
owner = [
BasicExpertInfo(
display_name=ticket.requester.name, email=ticket.requester.email
)
]
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
metadata: dict[str, str | list[str]] = {}
if ticket.status is not None:
metadata["status"] = ticket.status
if ticket.priority is not None:
metadata["priority"] = ticket.priority
if ticket.tags:
metadata["tags"] = ticket.tags
if ticket.type is not None:
metadata["ticket_type"] = ticket.type
# Fetch comments for the ticket
comments = self.zendesk_client.tickets.comments(ticket=ticket)
# Combine all comments into a single text
comments_text = "\n\n".join(
[
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
for comment in comments
if comment.body
]
)
# Combine ticket description and comments
description = (
ticket.description
if hasattr(ticket, "description") and ticket.description
else ""
)
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
# Extract subdomain from ticket.url
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
# Build the html url for the ticket
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
return Document(
id=f"zendesk_ticket_{ticket.id}",
sections=[Section(link=ticket_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=owner,
metadata=metadata,
)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
self.content_tags = _get_content_tag_mapping(self.client)
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
if self.content_type == "articles":
yield from self._poll_articles(start)
@@ -304,30 +204,26 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = _get_articles(self.client, start_time=int(start) if start else None)
# This one is built on the fly as there may be more many more authors than tags
author_map: dict[str, BasicExpertInfo] = {}
articles = (
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
if start is None
else self.zendesk_client.help_center.articles.incremental( # type: ignore
start_time=int(start)
)
)
doc_batch = []
for article in articles:
if (
article.get("body") is None
or article.get("draft")
article.body is None
or article.draft
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.get("label_names", [])
for label in article.label_names
)
):
continue
new_author_map, documents = _article_to_document(
article, self.content_tags, author_map, self.client
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
doc_batch.append(_article_to_document(article, self.content_tags))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
@@ -338,14 +234,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
author_map: dict[str, BasicExpertInfo] = {}
ticket_generator = _get_tickets(
self.client, start_time=int(start) if start else None
)
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
while True:
doc_batch = []
@@ -354,20 +246,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.get("status") == "deleted":
if ticket.status == "deleted":
continue
new_author_map, documents = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
doc_batch.append(self._ticket_to_document(ticket))
if len(doc_batch) >= self.batch_size:
yield doc_batch
@@ -385,6 +267,7 @@ class ZendeskConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()

View File

@@ -57,10 +57,10 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -22,7 +22,6 @@ from danswer.db.models import User
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from ee.danswer.db.api_key import get_api_key_email_pattern
def get_default_admin_user_emails() -> list[str]:
@@ -36,16 +35,12 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users_count(db_session: Session) -> int:
def get_total_users(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = (
db_session.query(User)
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
.count()
)
user_count = db_session.query(User).count()
invited_users = len(get_invited_users())
return user_count + invited_users

View File

@@ -388,7 +388,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_call))
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -474,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_call: ToolCall | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -494,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -513,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_call=tool_call,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -749,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@@ -10,10 +10,12 @@ from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Credential__UserGroup
@@ -440,7 +442,7 @@ def delete_google_drive_service_account_credentials(
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if credential.credential_json.get(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY):
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
db_session.delete(credential)
db_session.commit()

View File

@@ -37,10 +37,10 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -322,18 +322,11 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
"""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id is None:
tenant_id = previous_tenant_id
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -342,35 +335,30 @@ def get_session_with_tenant(
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
yield session
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
def set_search_path_on_checkout(

View File

@@ -734,10 +734,9 @@ class IndexAttempt(Base):
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
# Nullable because in the past, we didn't allow swapping out embedding models live
search_settings_id: Mapped[int] = mapped_column(
ForeignKey("search_settings.id", ondelete="SET NULL"),
nullable=True,
ForeignKey("search_settings.id"),
nullable=False,
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -757,7 +756,7 @@ class IndexAttempt(Base):
"ConnectorCredentialPair", back_populates="index_attempts"
)
search_settings: Mapped[SearchSettings | None] = relationship(
search_settings: Mapped[SearchSettings] = relationship(
"SearchSettings", back_populates="index_attempts"
)
@@ -918,15 +917,10 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=False
)
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
# Update the relationship
message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
back_populates="tool_call",
uselist=False,
"ChatMessage", back_populates="tool_calls"
)
@@ -1057,13 +1051,12 @@ class ChatMessage(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
tool_call: Mapped["ToolCall"] = relationship(
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
uselist=False,
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,

View File

@@ -14,6 +14,7 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -22,14 +23,7 @@ logger = setup_logger()
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one.
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
new index is done building. If so, swap the indices and expire the old one."""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
@@ -49,9 +43,9 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
now_old_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
search_settings=now_old_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
@@ -73,6 +67,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
return old_search_settings
if MULTI_TENANT:
return now_old_search_settings
return None

View File

@@ -13,7 +13,6 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
CSV = "csv"
class FileDescriptor(TypedDict):

View File

@@ -8,13 +8,12 @@ import requests
from sqlalchemy.orm import Session
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def load_chat_file(
@@ -53,11 +52,11 @@ def load_all_chat_files(
return files
def save_file_from_url(url: str, tenant_id: str) -> str:
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_tenant(tenant_id) as db_session:
with get_session_context_manager() as db_session:
response = requests.get(url)
response.raise_for_status()
@@ -76,10 +75,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url, tenant_id)) for url in urls
(save_file_from_url, (url,)) for url in urls
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)

View File

@@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -1,44 +1,72 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from langchain_core.messages import HumanMessage
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
CitationResponseHandler,
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.llm.answering.stream_processing.citation_processing import (
build_citation_processor,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
from danswer.llm.answering.stream_processing.quotes_processing import (
build_quotes_processor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
from danswer.tools.force import filter_tools_for_force_tool_use
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.images.prompt import build_image_generation_user_prompt
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -46,9 +74,29 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_answer_stream_processor(
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
)
if answer_style_configs.quotes_config:
return build_quotes_processor(
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
)
raise RuntimeError("Not implemented yet")
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
logger = setup_logger()
class Answer:
def __init__(
self,
@@ -88,6 +136,8 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
@@ -112,141 +162,335 @@ class Answer:
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.using_tool_calling_llm = (
explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
if self.answer_style_config.citation_config:
prompt_builder.update_system_prompt(
build_citations_system_message(self.prompt_config)
)
and not skip_explicit_tool_calling
)
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = next(
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
)
if tool is None:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
logger.info(
f"Forcefully using tool='{tool.name}'"
+ (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args is not None
else ""
)
)
return [tool]
def _handle_specified_tool_call(
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
else:
raise RuntimeError("Tool call handler did not return a new LLM call")
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = next(
(t for t in current_llm_call.tools if t.name == tool_name), None
)
if not tool:
raise RuntimeError(f"Tool '{tool_name}' not found")
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# special pre-logic for non-tool calling LLM case
if not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
prompt_builder.update_user_prompt(
build_citations_user_message(
question=self.question,
prompt_config=self.prompt_config,
context_docs=final_context_documents,
files=self.latest_query_files,
all_doc_useful=(
self.answer_style_config.citation_config.all_docs_useful
if self.answer_style_config.citation_config
else False
),
history_message=self.single_message_history or "",
)
)
elif self.answer_style_config.quotes_config:
prompt_builder.update_user_prompt(
build_quotes_user_message(
question=self.question,
context_docs=final_context_documents,
history_str=self.single_message_history or "",
prompt=self.prompt_config,
)
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
def _raw_output_for_explicit_tool_calling_llms(
self,
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
content="",
)
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
# if tool calling is supported, first try the raw message
# to see if we don't need to use any tools
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
if tool_call_chunk is None:
tool_call_chunk = message
else:
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled:
return
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(
stop_reason=StreamStopReason.CONTEXT_LENGTH
)
if not tool_call_chunk:
return # no tool call needed
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
yield from tool_runner.tool_responses()
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from self._process_llm_stream(
prompt=prompt,
# as of now, we don't support multiple tool calls in sequence, which is why
# we don't need to pass this in here
# tools=[tool.tool_definition() for tool in self.tools],
)
return
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
self,
prompt: Any,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk):
if message.content:
if self.is_cancelled:
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield cast(str, message.content)
search_result = SearchTool.get_search_result(current_llm_call) or []
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
if not tool:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
if self.skip_gen_ai_answer_generation:
raise ValueError(
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
yield response
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
)
)
else:
raise ValueError("No answer style config provided")
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
yield final
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build()
# DEBUG: good breakpoint
stream = self.llm.stream(
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -254,30 +498,94 @@ class Answer:
yield from self._processed_stream
return
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
),
message_history=self.message_history,
llm_config=self.llm.config,
single_message_history=self.single_message_history,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
output_generator = (
self._raw_output_for_explicit_tool_calling_llms()
if explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not self.skip_explicit_tool_calling
else self._raw_output_for_non_explicit_tool_calling_llms()
)
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
message = None
# special things we need to keep track of for the SearchTool
# raw results that will be displayed to the user
search_results: list[LlmDoc] | None = None
# processed docs to feed into the LLM
final_context_docs: list[LlmDoc] | None = None
for message in stream:
if isinstance(message, ToolCallKickoff) or isinstance(
message, ToolCallFinalResult
):
yield message
elif isinstance(message, ToolResponse):
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
# We don't need to run section merging in this flow, this variable is only used
# below to specify the ordering of the documents for the purpose of matching
# citations to the right search documents. The deduplication logic is more lightweight
# there and we don't need to do it twice
search_results = [
llm_doc_from_inference_section(section)
for section in cast(
SearchResponseSummary, message.response
).top_sections
]
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_docs = cast(list[LlmDoc], message.response)
yield message
elif (
message.id == SEARCH_DOC_CONTENT_ID
and not self._return_contexts
):
continue
yield message
else:
# assumes all tool responses will come first, then the final answer
break
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
stream_stop_info = None
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
# this should never happen, but we're seeing weird behavior here so handling for now
if not isinstance(item, str):
logger.error(
f"Received non-string item in answer stream: {item}. Skipping."
)
continue
yield item
yield from process_answer_stream_fn(_stream())
if stream_stop_info:
yield stream_stop_info
processed_stream = []
for processed_packet in self._get_response([llm_call]):
for processed_packet in _process_stream(output_generator):
processed_stream.append(processed_packet)
yield processed_packet
@@ -301,6 +609,7 @@ class Answer:
return citations
@property
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True

View File

@@ -1,84 +0,0 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| DanswerQuotes
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True
class LLMResponseHandlerManager:
def __init__(
self,
tool_handler: "ToolResponseHandler",
answer_handler: "AnswerResponseHandler",
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler
self.answer_handler = answer_handler
self.is_cancelled = is_cancelled
def handle_llm_response(
self,
stream: Iterator[BaseMessage],
) -> Generator[ResponsePart, None, None]:
all_messages: list[BaseMessage] = []
for message in stream:
if self.is_cancelled():
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
return
# tool handler doesn't do anything until the full message is received
# NOTE: still need to run list() to get this to run
list(self.tool_handler.handle_response_part(message, all_messages))
yield from self.answer_handler.handle_response_part(message, all_messages)
all_messages.append(message)
# potentially give back all info on the selected tool call + its result
yield from self.tool_handler.handle_response_part(None, all_messages)
yield from self.answer_handler.handle_response_part(None, all_messages)
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
return self.tool_handler.next_llm_call(llm_call)

View File

@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
tool_calls: list[ToolCallFinalResult]
@classmethod
def from_chat_message(
@@ -51,13 +51,14 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -12,12 +12,12 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.message import ToolCallSummary
def default_build_system_message(
@@ -54,14 +54,18 @@ def default_build_user_message(
class AnswerPromptBuilder:
def __init__(
self,
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
single_message_history: str | None = None,
self, message_history: list[PreviousMessage], llm_config: LLMConfig
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
llm_tokenizer = get_tokenizer(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
@@ -70,24 +74,6 @@ class AnswerPromptBuilder:
Callable[[str], list[int]], llm_tokenizer.encode
)
self.raw_message_history = message_history
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
# for cases where like the QA flow where we want to condense the chat history
# into a single message rather than a sequence of User / Assistant messages
self.single_message_history = single_message_history
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
@@ -99,21 +85,18 @@ class AnswerPromptBuilder:
)
def update_user_prompt(self, user_message: HumanMessage) -> None:
if not user_message:
self.user_message_and_token_cnt = None
return
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
def append_message(self, message: BaseMessage) -> None:
"""Append a new message to the message history."""
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
self.new_messages_and_token_cnts.append((message, token_count))
def get_user_message_content(self) -> str:
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
return query
def build(self) -> list[BaseMessage]:
def build(
self, tool_call_summary: ToolCallSummary | None = None
) -> list[BaseMessage]:
if not self.user_message_and_token_cnt:
raise ValueError("User message must be set before building prompt")
@@ -130,8 +113,25 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
if tool_call_summary:
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_request,
check_message_tokens(
tool_call_summary.tool_call_request,
self.llm_tokenizer_encode_func,
),
)
)
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_result,
check_message_tokens(
tool_call_summary.tool_call_result,
self.llm_tokenizer_encode_func,
),
)
)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -6,6 +6,7 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
@@ -13,7 +14,6 @@ from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
@@ -132,9 +132,10 @@ def build_citations_system_message(
def build_citations_user_message(
message: HumanMessage,
question: str,
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
files: list[InMemoryChatFile],
all_doc_useful: bool,
history_message: str = "",
) -> HumanMessage:
@@ -148,7 +149,6 @@ def build_citations_user_message(
if history_message
else ""
)
query, img_urls = message_to_prompt_and_imgs(message)
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
@@ -158,22 +158,20 @@ def build_citations_user_message(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=question,
history_block=history_block,
)
else:
# if no context docs provided, assume we're in the tool calling flow
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=question,
history_block=history_block,
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, img_urls=img_urls)
if img_urls
else user_prompt
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg

View File

@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
@@ -76,7 +75,7 @@ def _build_strong_llm_quotes_prompt(
def build_quotes_user_message(
message: HumanMessage,
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
@@ -87,10 +86,28 @@ def build_quotes_user_message(
else _build_strong_llm_quotes_prompt
)
query, _ = message_to_prompt_and_imgs(message)
return prompt_builder(
question=query,
question=question,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
)
def build_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
return prompt_builder(
question=question,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,

View File

@@ -19,7 +19,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.tools.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger

View File

@@ -1,91 +0,0 @@
import abc
from collections.abc import Generator
from langchain_core.messages import BaseMessage
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
class AnswerResponseHandler(abc.ABC):
@abc.abstractmethod
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
raise NotImplementedError
class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
# This is a dummy handler that returns nothing
yield from []
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)
class QuotesResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.quotes_processor = QuotesProcessor(
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self.quotes_processor.process_token(None)
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
yield from self.quotes_processor.process_token(content)

View File

@@ -1,10 +1,12 @@
import re
from collections.abc import Generator
from collections.abc import Iterator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -17,104 +19,128 @@ def in_code_block(llm_text: str) -> bool:
return count % 2 != 0
class CitationProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
self.curr_segment = ""
self.cited_inds: set[int] = set()
self.hold = ""
self.current_citations: list[int] = []
self.past_cite_count = 0
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
"""
Key aspects:
def process_token(
self, token: str | None
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
# None -> end of stream
if token is None:
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
return
1. Stream Processing:
- Processes tokens one by one, allowing for real-time handling of large texts.
if self.stop_stream:
next_hold = self.hold + token
if self.stop_stream in next_hold:
return
if next_hold == self.stop_stream[: len(next_hold)]:
self.hold = next_hold
return
2. Citation Detection:
- Uses regex to find citations in the format [number].
- Example: [1], [2], etc.
3. Citation Mapping:
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
4. Citation Formatting:
- Replaces citations with properly formatted versions.
- Adds links if available: [[1]](https://example.com)
- Handles cases where links are not available: [[1]]()
5. Duplicate Handling:
- Skips consecutive citations of the same document to avoid redundancy.
6. Output Generation:
- Yields DanswerAnswerPiece objects for regular text.
- Yields CitationInfo objects for each unique citation encountered.
7. Context Awareness:
- Uses context_docs to access document information for citations.
This function effectively processes a stream of text, identifies and reformats citations,
and provides both the processed text and citation information as output.
"""
order_mapping = doc_id_to_rank_map.order_mapping
llm_out = ""
max_citation_num = len(context_docs)
citation_order = []
curr_segment = ""
cited_inds = set()
hold = ""
raw_out = ""
current_citations: list[int] = []
past_cite_count = 0
for raw_token in tokens:
raw_out += raw_token
if stop_stream:
next_hold = hold + raw_token
if stop_stream in next_hold:
break
if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue
token = next_hold
self.hold = ""
hold = ""
else:
token = raw_token
self.curr_segment += token
self.llm_out += token
curr_segment += token
llm_out += token
# Handle code blocks without language tags
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
return
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
if "`" in curr_segment:
if curr_segment.endswith("`"):
continue
elif "```" in curr_segment:
piece_that_comes_after = curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(llm_out):
curr_segment = curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
citations_found = list(re.finditer(citation_pattern, curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
# `past_cite_count`: number of characters since past citation
# 5 to ensure a citation hasn't occured
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
current_citations = []
result = "" # Initialize result here
if citations_found and not in_code_block(self.llm_out):
if citations_found and not in_code_block(llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id]
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[numerical_value - 1]
real_citation_num = order_mapping[context_llm_doc.document_id]
if real_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num)
if real_citation_num not in citation_order:
citation_order.append(real_citation_num)
target_citation_num = (
self.citation_order.index(real_citation_num) + 1
)
target_citation_num = citation_order.index(real_citation_num) + 1
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
if target_citation_num in current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
curr_segment = (
curr_segment[: length_to_add + start]
+ curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
# by allowing it to generate citations on its own.
if curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
context_llm_doc = context_docs[doc_id - 1]
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
@@ -124,57 +150,75 @@ class CitationProcessor:
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
# Will continue attempt on next loops
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
+ curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
past_cite_count = len(llm_out)
current_citations.append(target_citation_num)
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
+ curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
length_to_add += len(curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
+ curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
length_to_add += len(curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]
self.curr_segment = self.curr_segment[last_citation_end:]
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
curr_segment = curr_segment[last_citation_end:]
if possible_citation_found:
continue
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
if not possible_citation_found:
result += self.curr_segment
self.curr_segment = ""
if curr_segment:
yield DanswerAnswerPiece(answer_piece=curr_segment)
if result:
yield DanswerAnswerPiece(answer_piece=result)
def build_citation_processor(
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
) -> StreamProcessor:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,
doc_id_to_rank_map=doc_id_to_rank_map,
)
return stream_processor

View File

@@ -1,11 +1,14 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
import regex
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
@@ -154,7 +157,7 @@ def separate_answer_quotes(
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
def _process_answer(
def process_answer(
answer_raw: str,
docs: list[LlmDoc],
is_json_prompt: bool = True,
@@ -192,7 +195,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
def _extract_quotes_from_completed_token_stream(
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
) -> DanswerQuotes:
answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
if answer:
logger.notice(answer)
elif model_output:
@@ -201,101 +204,94 @@ def _extract_quotes_from_completed_token_stream(
return quotes
class QuotesProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.context_docs = context_docs
self.is_json_prompt = is_json_prompt
def process_model_tokens(
tokens: Iterator[str],
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
"""Used in the streaming case to process the model output
into an Answer and Quotes
self.found_answer_start = False if is_json_prompt else True
self.found_answer_end = False
self.hold_quote = ""
self.model_output = ""
self.hold = ""
Yields Answer tokens back out in a dict for streaming to frontend
When Answer section ends, yields dict with answer_finished key
Collects all the tokens at the end to form the complete model output"""
quote_pat = f"\n{QUOTE_PAT}"
# Sometimes worse model outputs new line instead of :
quote_loose = f"\n{quote_pat[:-1]}\n"
# Sometime model outputs two newlines before quote section
quote_pat_full = f"\n{quote_pat}"
model_output: str = ""
found_answer_start = False if is_json_prompt else True
found_answer_end = False
hold_quote = ""
def process_token(
self, token: str | None
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
# None -> end of stream
if token is None:
if self.model_output:
yield _extract_quotes_from_completed_token_stream(
model_output=self.model_output,
context_docs=self.context_docs,
is_json_prompt=self.is_json_prompt,
)
return
for token in tokens:
model_previous = model_output
model_output += token
model_previous = self.model_output
self.model_output += token
if not self.found_answer_start:
m = answer_pattern.search(self.model_output)
if not found_answer_start:
m = answer_pattern.search(model_output)
if m:
self.found_answer_start = True
found_answer_start = True
# Prevent heavy cases of hallucinations
if self.is_json_prompt and len(self.model_output) > 70:
# Prevent heavy cases of hallucinations where model is never providing a JSON
# We want to quickly update the user - not stream forever
if is_json_prompt and len(model_output) > 70:
logger.warning("LLM did not produce json as prompted")
self.found_answer_end = True
return
found_answer_end = True
continue
remaining = self.model_output[m.end() :]
# Look for an unescaped quote, which means the answer is entirely contained
# in this token e.g. if the token is `{"answer": "blah", "qu`
quote_indices = [i for i, char in enumerate(remaining) if char == '"']
for quote_idx in quote_indices:
# Check if quote is escaped by counting backslashes before it
num_backslashes = 0
pos = quote_idx - 1
while pos >= 0 and remaining[pos] == "\\":
num_backslashes += 1
pos -= 1
# If even number of backslashes, quote is not escaped
if num_backslashes % 2 == 0:
yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx])
return
# If no unescaped quote found, yield the remaining string
remaining = model_output[m.end() :]
if len(remaining) > 0:
yield DanswerAnswerPiece(answer_piece=remaining)
return
continue
if self.found_answer_start and not self.found_answer_end:
if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
self.found_answer_end = True
if found_answer_start and not found_answer_end:
if is_json_prompt and _stream_json_answer_end(model_previous, token):
found_answer_end = True
# return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.'
if token:
try:
answer_token_section = token.index('"')
yield DanswerAnswerPiece(
answer_piece=self.hold_quote + token[:answer_token_section]
answer_piece=hold_quote + token[:answer_token_section]
)
except ValueError:
logger.error("Quotation mark not found in token")
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
yield DanswerAnswerPiece(answer_piece=None)
return
elif not self.is_json_prompt:
quote_pat = f"\n{QUOTE_PAT}"
quote_loose = f"\n{quote_pat[:-1]}\n"
quote_pat_full = f"\n{quote_pat}"
if (
quote_pat in self.hold_quote + token
or quote_loose in self.hold_quote + token
):
self.found_answer_end = True
continue
elif not is_json_prompt:
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
return
if self.hold_quote + token in quote_pat_full:
self.hold_quote += token
return
continue
if hold_quote + token in quote_pat_full:
hold_quote += token
continue
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
hold_quote = ""
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
self.hold_quote = ""
logger.debug(f"Raw Model QnA Output: {model_output}")
yield _extract_quotes_from_completed_token_stream(
model_output=model_output,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def build_quotes_processor(
context_docs: list[LlmDoc], is_json_prompt: bool
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
return stream_processor

View File

@@ -1,207 +0,0 @@
from collections.abc import Generator
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.utils.logger import setup_logger
logger = setup_logger()
class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
self.tool_call_chunk: AIMessageChunk | None = None
self.tool_call_requests: list[ToolCall] = []
self.tool_runner: ToolRunner | None = None
self.tool_call_summary: ToolCallSummary | None = None
self.tool_kickoff: ToolCallKickoff | None = None
self.tool_responses: list[ToolResponse] = []
self.tool_final_result: ToolCallFinalResult | None = None
@classmethod
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)
tool_args = (
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(llm_call.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
return
self.tool_call_requests = self.tool_call_chunk.tool_calls
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in self.tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
continue
else:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
if selected_tool and selected_tool_call_request:
break
if not selected_tool or not selected_tool_call_request:
return
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
self.tool_kickoff = self.tool_runner.kickoff()
yield self.tool_kickoff
for response in self.tool_runner.tool_responses():
self.tool_responses.append(response)
yield response
self.tool_final_result = self.tool_runner.tool_final_result()
yield self.tool_final_result
self.tool_call_summary = ToolCallSummary(
tool_call_request=self.tool_call_chunk,
tool_call_result=build_tool_message(
selected_tool_call_request, self.tool_runner.tool_message_content()
),
)
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self._handle_tool_call()
if isinstance(response_item, AIMessageChunk) and (
response_item.tool_call_chunks or response_item.tool_calls
):
if self.tool_call_chunk is None:
self.tool_call_chunk = response_item
else:
self.tool_call_chunk += response_item # type: ignore
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None
or self.tool_call_summary is None
or self.tool_kickoff is None
or self.tool_final_result is None
):
return None
tool_runner = self.tool_runner
new_prompt_builder = tool_runner.tool.build_next_prompt(
prompt_builder=current_llm_call.prompt_builder,
tool_call_summary=self.tool_call_summary,
tool_responses=self.tool_responses,
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
)
return LLMCall(
prompt_builder=new_prompt_builder,
tools=[], # for now, only allow one tool call per response
force_use_tool=ForceUseTool(
force_use=False,
tool_name="",
args=None,
),
files=current_llm_call.files,
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
tool_call_info=[
self.tool_kickoff,
*self.tool_responses,
self.tool_final_result,
],
)

View File

@@ -83,10 +83,8 @@ def _convert_litellm_message_to_langchain_message(
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else [],
for tool_call in (tool_calls if tool_calls else [])
],
)
elif role == "system":
return SystemMessage(content=content)

View File

@@ -1,4 +1,3 @@
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -8,7 +7,6 @@ from typing import TYPE_CHECKING
from typing import Union
import litellm # type: ignore
import pandas as pd
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -137,18 +135,6 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def _process_csv_file(file: InMemoryChatFile) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string()
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
if file.filename
else "CSV FILE (NO NAME PROVIDED):\n"
)
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
@@ -159,26 +145,16 @@ def _build_content(
if files
else None
)
csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
if not text_files and not csv_files:
if not text_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files or []:
for file in text_files:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)
final_message_with_files += message
return final_message_with_files
@@ -227,28 +203,6 @@ def build_content_with_imgs(
)
def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
if isinstance(message.content, str):
return message.content, []
imgs = []
texts = []
for part in message.content:
if isinstance(part, dict):
if part.get("type") == "image_url":
img_url = part.get("image_url", {}).get("url")
if img_url:
imgs.append(img_url)
elif part.get("type") == "text":
text = part.get("text")
if text:
texts.append(text)
else:
texts.append(part)
return "".join(texts), imgs
def dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:

View File

@@ -52,16 +52,12 @@ from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephr
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -206,33 +202,30 @@ def stream_answer_objects(
max_tokens=max_document_tokens,
)
answer_config = AnswerStyleConfig(
citation_config=CitationConfig() if use_citations else None,
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
search_tool = SearchTool(
db_session=db_session,
user=user,
evaluation_type=(
LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type
),
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
)
answer_config = AnswerStyleConfig(
citation_config=CitationConfig() if use_citations else None,
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
answer = Answer(
question=query_msg.message,
answer_style_config=answer_config,

View File

@@ -1,72 +0,0 @@
import redis
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
from danswer.redis.redis_pool import get_redis_client
class RedisConnector:
"""Composes several classes to simplify interacting with a connector and its
associated background tasks / associated redis interactions."""
def __init__(self, tenant_id: str | None, id: int) -> None:
self.tenant_id: str | None = tenant_id
self.id: int = id
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.id, search_settings_id, self.redis
)
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id

View File

@@ -1,97 +0,0 @@
import time
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.redis.redis_object_helper import RedisObjectHelper
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)

View File

@@ -1,145 +0,0 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime
class RedisConnectorDelete:
"""Manages interactions with redis for deletion tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectordeletion"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
def get_remaining(self) -> int:
# todo: move into fence
remaining = cast(int, self.redis.scard(self.taskset_key))
return remaining
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
@property
def payload(self) -> RedisConnectorDeletionFenceData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "connectordeletion_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
return f"{self.PREFIX}_{self.id}_{uuid4()}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
custom_task_id = self._generate_task_id()
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
self.redis.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=self.tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}"
r.srem(taskset_key, task_id)
return
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -1,146 +0,0 @@
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
class RedisConnectorIndex:
"""Manages interactions with redis for indexing tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorindexing"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
GENERATOR_TASK_PREFIX = PREFIX + "+generator" # "connectorindexing+generator_fence"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorindexing_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
def __init__(
self,
tenant_id: str | None,
id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.search_settings_id = search_settings_id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
self.generator_progress_key = (
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
)
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
def generate_generator_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
@property
def payload(self) -> RedisConnectorIndexingFenceData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(
self,
payload: RedisConnectorIndexingFenceData | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_progress(self) -> int | None:
"""Returns None if the key doesn't exist. The"""
# TODO: move into fence?
bytes = self.redis.get(self.generator_progress_key)
if bytes is None:
return None
progress = int(cast(int, bytes))
return progress
def get_completion(self) -> int | None:
# TODO: move into fence?
bytes = self.redis.get(self.generator_complete_key)
if bytes is None:
return None
status = int(cast(int, bytes))
return status
def reset(self) -> None:
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.fence_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -1,171 +0,0 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
class RedisConnectorPrune:
"""Manages interactions with redis for pruning tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorpruning"
FENCE_PREFIX = f"{PREFIX}_fence"
# phase 1 - geneartor task and progress signals
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpruning+generator
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorpruning_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorpruning_generator_complete
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_remaining(self) -> int:
# todo: move into fence
remaining = cast(int, self.redis.scard(self.taskset_key))
return remaining
def get_active_task_count(self) -> int:
"""Count of active pruning tasks"""
count = 0
for key in self.redis.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
count += 1
return count
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, value: bool) -> None:
if not value:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
pruning tasks to be processed ... just after the generator completes."""
fence_bytes = self.redis.get(self.generator_complete_key)
if fence_bytes is None:
return None
fence_int = cast(int, fence_bytes)
return fence_int
@generator_complete.setter
def generator_complete(self, payload: int | None) -> None:
"""Set the payload to an int to set the fence, otherwise if None it will
be deleted"""
if payload is None:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generate_tasks(
self,
documents_to_prune: set[str],
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
if not cc_pair:
return None
for doc_id in documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
self.redis.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=self.tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}"
r.srem(taskset_key, task_id)
return
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -1,34 +0,0 @@
import redis
class RedisConnectorStop:
"""Manages interactions with redis for stop signaling. Should only be accessed
through RedisConnector."""
FENCE_PREFIX = "connectorstop_fence"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id: int = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, value: bool) -> None:
if not value:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
@staticmethod
def reset_all(r: redis.Redis) -> None:
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -1,99 +0,0 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.document_set import construct_document_select_by_docset
from danswer.redis.redis_object_helper import RedisObjectHelper
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, payload: int | None) -> None:
if payload is None:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload)
@property
def payload(self) -> int | None:
bytes = self.redis.get(self.fence_key)
if bytes is None:
return None
progress = int(cast(int, bytes))
return progress
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -1,91 +0,0 @@
from abc import ABC
from abc import abstractmethod
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.redis.redis_pool import get_redis_client
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: str):
self._tenant_id: str | None = tenant_id
self._id: str = id
self.redis = get_redis_client(tenant_id=tenant_id)
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass

View File

@@ -1,112 +0,0 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.redis.redis_object_helper import RedisObjectHelper
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, payload: int | None) -> None:
if payload is None:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload)
@property
def payload(self) -> int | None:
bytes = self.redis.get(self.fence_key)
if bytes is None:
return None
progress = int(cast(int, bytes))
return progress
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -10,7 +10,6 @@ from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep
@@ -101,7 +100,6 @@ def check_router_auth(
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):
found_auth = True
break

View File

@@ -11,6 +11,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
@@ -37,7 +39,6 @@ from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
@@ -96,6 +97,8 @@ def get_cc_pair_full_info(
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
r = get_redis_client(tenant_id=tenant_id)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False
)
@@ -131,9 +134,9 @@ def get_cc_pair_full_info(
)
search_settings = get_current_search_settings(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings.id)
rci = RedisConnectorIndexing(
cc_pair_id=cc_pair_id, search_settings_id=search_settings.id
)
return CCPairFullInfo.from_models(
cc_pair_model=cc_pair,
@@ -150,7 +153,7 @@ def get_cc_pair_full_info(
),
num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user,
indexing=redis_connector_index.fenced,
indexing=rci.is_indexing(r),
)
@@ -260,9 +263,8 @@ def prune_cc_pair(
)
r = get_redis_client(tenant_id=tenant_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(r):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",

View File

@@ -9,13 +9,13 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from google.oauth2.credentials import Credentials # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.background.celery.versioned_apps.primary import app as primary_app
@@ -35,7 +35,6 @@ from danswer.connectors.gmail.connector_auth import (
)
from danswer.connectors.gmail.connector_auth import upsert_google_app_gmail_cred
from danswer.connectors.google_drive.connector_auth import build_service_account_creds
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.connector_auth import delete_google_app_cred
from danswer.connectors.google_drive.connector_auth import delete_service_account_key
from danswer.connectors.google_drive.connector_auth import get_auth_url
@@ -44,13 +43,13 @@ from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_authorized_user,
)
from danswer.connectors.google_drive.connector_auth import get_service_account_key
from danswer.connectors.google_drive.connector_auth import GOOGLE_DRIVE_SCOPES
from danswer.connectors.google_drive.connector_auth import (
update_credential_access_tokens,
)
from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred
from danswer.connectors.google_drive.connector_auth import upsert_service_account_key
from danswer.connectors.google_drive.connector_auth import verify_csrf
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.db.connector import create_connector
from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
@@ -84,7 +83,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
@@ -296,7 +294,7 @@ def upsert_service_account_credential(
try:
credential_base = build_service_account_creds(
DocumentSource.GOOGLE_DRIVE,
primary_admin_email=service_account_credential_request.google_drive_primary_admin,
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -322,7 +320,7 @@ def upsert_gmail_service_account_credential(
try:
credential_base = build_service_account_creds(
DocumentSource.GMAIL,
primary_admin_email=service_account_credential_request.gmail_delegated_user,
delegated_user_email=service_account_credential_request.gmail_delegated_user,
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -350,14 +348,27 @@ def check_drive_tokens(
return AuthStatus(authenticated=False)
token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
google_drive_creds = get_google_drive_creds_for_authorized_user(
token_json_str=token_json_str,
scopes=GOOGLE_DRIVE_SCOPES,
token_json_str=token_json_str
)
if google_drive_creds is None:
return AuthStatus(authenticated=False)
return AuthStatus(authenticated=True)
@router.get("/admin/connector/google-drive/authorize/{credential_id}")
def admin_google_drive_auth(
response: Response, credential_id: str, _: User = Depends(current_admin_user)
) -> AuthUrl:
# set a cookie that we can read in the callback (used for `verify_csrf`)
response.set_cookie(
key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME,
value=credential_id,
httponly=True,
max_age=600,
)
return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id)))
@router.post("/admin/connector/file/upload")
def upload_files(
files: list[UploadFile],
@@ -486,10 +497,12 @@ def get_connector_indexing_status(
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client(tenant_id=tenant_id)
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
# connector or credential may be None.
# Additional checks are done to make sure the connector and credential still exist.
# Additional checks are done to make sure the connector and credential still exists.
# TODO: make this one query ... possibly eager load or wrap in a read transaction
# to avoid the complexity of trying to error check throughout the function
cc_pairs = get_connector_credential_pairs(
@@ -556,9 +569,8 @@ def get_connector_indexing_status(
in_progress = False
if search_settings:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
if r.exists(rci.fence_key):
in_progress = True
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
@@ -939,11 +951,10 @@ def google_drive_callback(
)
credential_id = int(credential_id_cookie)
verify_csrf(credential_id, callback.state)
credentials: Credentials | None = update_credential_access_tokens(
callback.code, credential_id, user, db_session
)
if credentials is None:
if (
update_credential_access_tokens(callback.code, credential_id, user, db_session)
is None
):
raise HTTPException(
status_code=500, detail="Unable to fetch Google Drive access tokens"
)

View File

@@ -81,6 +81,18 @@ def get_cc_source_full_info(
]
@router.get("/credential/{id}")
def list_credentials_by_id(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
credentials = fetch_credentials(db_session=db_session, user=user)
return [
CredentialSnapshot.from_credential_db_model(credential)
for credential in credentials
]
@router.delete("/admin/credential/{credential_id}")
def delete_credential_by_id_admin(
credential_id: int,

View File

@@ -377,16 +377,16 @@ class GoogleServiceAccountKey(BaseModel):
class GoogleServiceAccountCredentialRequest(BaseModel):
google_drive_primary_admin: str | None = None # email of user to impersonate
google_drive_delegated_user: str | None = None # email of user to impersonate
gmail_delegated_user: str | None = None # email of user to impersonate
@model_validator(mode="after")
def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
if (self.google_drive_primary_admin is None) == (
if (self.google_drive_delegated_user is None) == (
self.gmail_delegated_user is None
):
raise ValueError(
"Exactly one of google_drive_primary_admin or gmail_delegated_user must be set"
"Exactly one of google_drive_delegated_user or gmail_delegated_user must be set"
)
return self

View File

@@ -9,7 +9,7 @@ from danswer.db.models import StarterMessage
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot
from danswer.server.features.tool.models import ToolSnapshot
from danswer.server.features.tool.api import ToolSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.utils.logger import setup_logger

View File

@@ -18,16 +18,10 @@ from danswer.db.tools import update_tool
from danswer.server.features.tool.models import CustomToolCreate
from danswer.server.features.tool.models import CustomToolUpdate
from danswer.server.features.tool.models import ToolSnapshot
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from danswer.tools.tool_implementations.custom.openapi_parsing import (
openapi_to_method_specs,
)
from danswer.tools.tool_implementations.custom.openapi_parsing import (
validate_openapi_schema,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.utils import is_image_generation_available
router = APIRouter(prefix="/tool")

View File

@@ -57,7 +57,6 @@ class UserInfo(BaseModel):
oidc_expiry: datetime | None = None
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
organization_name: str | None = None
@classmethod
@@ -66,7 +65,6 @@ class UserInfo(BaseModel):
user: User,
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
organization_name: str | None = None,
) -> "UserInfo":
return cls(
@@ -92,7 +90,6 @@ class UserInfo(BaseModel):
oidc_expiry=user.oidc_expiry if TRACK_EXTERNAL_IDP_EXPIRY else None,
current_token_created_at=current_token_created_at,
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
)

View File

@@ -35,10 +35,9 @@ from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.auth import get_total_users_count
from danswer.db.auth import get_total_users
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
@@ -191,6 +190,7 @@ def bulk_invite_users(
)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
normalized_emails = []
try:
for email in emails:
@@ -206,7 +206,6 @@ def bulk_invite_users(
if MULTI_TENANT:
try:
add_users_to_tenant(normalized_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
@@ -214,8 +213,6 @@ def bulk_invite_users(
detail="User has already been invited to a Danswer organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
initial_invited_users = get_invited_users()
@@ -227,7 +224,7 @@ def bulk_invite_users(
try:
logger.info("Registering tenant users")
register_tenant_users(
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
)
if ENABLE_EMAIL_INVITES:
try:
@@ -263,7 +260,7 @@ def remove_invited_user(
try:
if MULTI_TENANT:
register_tenant_users(
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
)
except Exception:
logger.error(
@@ -477,7 +474,6 @@ def verify_user_logged_in(
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
# to enforce user verification here - the frontend always wants to get the info about
# the current user regardless of if they are currently verified
if user is None:
# if auth type is disabled, return a dummy user with preferences from
# the key-value store
@@ -504,7 +500,6 @@ def verify_user_logged_in(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
organization_name=organization_name,
)

View File

@@ -1,6 +1,5 @@
import asyncio
import io
import json
import uuid
from collections.abc import Callable
from collections.abc import Generator
@@ -284,14 +283,13 @@ def delete_chat_session_by_id(
raise HTTPException(status_code=400, detail=str(e))
async def is_connected(request: Request) -> Callable[[], bool]:
async def is_disconnected(request: Request) -> Callable[[], bool]:
main_loop = asyncio.get_event_loop()
def is_connected_sync() -> bool:
def is_disconnected_sync() -> bool:
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
is_connected = not future.result(timeout=0.01)
return is_connected
return not future.result(timeout=0.01)
except asyncio.TimeoutError:
logger.error("Asyncio timed out")
return True
@@ -302,7 +300,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
)
return True
return is_connected_sync
return is_disconnected_sync
@router.post("/send-message")
@@ -311,7 +309,7 @@ def handle_new_chat_message(
request: Request,
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
"""
This endpoint is both used for all the following purposes:
@@ -327,7 +325,7 @@ def handle_new_chat_message(
request (Request): The current HTTP request context.
user (User | None): The current user, obtained via dependency injection.
_ (None): Rate limit check is run if user/group/global rate limits are enabled.
is_connected_func (Callable[[], bool]): Function to check client disconnection,
is_disconnected_func (Callable[[], bool]): Function to check client disconnection,
used to stop the streaming response if the client disconnects.
Returns:
@@ -342,6 +340,8 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
import json
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_chat_message(
@@ -354,7 +354,7 @@ def handle_new_chat_message(
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
is_connected=is_connected_func,
is_connected=is_disconnected_func,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
@@ -362,9 +362,6 @@ def handle_new_chat_message(
logger.exception(f"Error in chat message streaming: {e}")
yield json.dumps({"error": str(e)})
finally:
logger.debug("Stream generator finished")
return StreamingResponse(stream_generator(), media_type="text/event-stream")
@@ -557,9 +554,9 @@ def upload_files_for_chat(
_: User | None = Depends(current_user),
) -> dict[str, list[FileDescriptor]]:
image_content_types = {"image/jpeg", "image/png", "image/webp"}
csv_content_types = {"text/csv"}
text_content_types = {
"text/plain",
"text/csv",
"text/markdown",
"text/x-markdown",
"text/x-config",
@@ -578,10 +575,8 @@ def upload_files_for_chat(
"application/epub+zip",
}
allowed_content_types = (
image_content_types.union(text_content_types)
.union(document_content_types)
.union(csv_content_types)
allowed_content_types = image_content_types.union(text_content_types).union(
document_content_types
)
for file in files:
@@ -591,10 +586,6 @@ def upload_files_for_chat(
elif file.content_type in text_content_types:
error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
".log, .tsv."
elif file.content_type in csv_content_types:
error_detail = (
"Unsupported CSV file type. Supported CSV types include .csv."
)
else:
error_detail = (
"Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
@@ -620,10 +611,6 @@ def upload_files_for_chat(
file_type = ChatFileType.IMAGE
# Convert image to JPEG
file_content, new_content_type = convert_to_jpeg(file)
elif file.content_type in csv_content_types:
file_type = ChatFileType.CSV
file_content = io.BytesIO(file.file.read())
new_content_type = file.content_type or ""
elif file.content_type in document_content_types:
file_type = ChatFileType.DOC
file_content = io.BytesIO(file.file.read())

View File

@@ -188,7 +188,7 @@ class ChatMessageDetail(BaseModel):
chat_session_id: UUID | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
tool_calls: list[ToolCallFinalResult]
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore

View File

@@ -21,7 +21,7 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -1,59 +0,0 @@
from typing import cast
from typing import TYPE_CHECKING
from langchain_core.messages import HumanMessage
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.tool_implementations.custom.custom_tool import (
CustomToolCallSummary,
)
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
def build_user_message_for_non_tool_calling_llm(
message: HumanMessage,
tool_name: str,
*args: "ToolResponse",
) -> str:
query, _ = message_to_prompt_and_imgs(message)
tool_run_summary = cast("CustomToolCallSummary", args[0].response).tool_result
return f"""
Here's the result from the {tool_name} tool:
{tool_run_summary}
Now respond to the following:
{query}
""".strip()
class BaseTool(Tool):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",
tool_call_summary: "ToolCallSummary",
tool_responses: list["ToolResponse"],
using_tool_calling_llm: bool,
) -> "AnswerPromptBuilder":
if using_tool_calling_llm:
prompt_builder.append_message(tool_call_summary.tool_call_request)
prompt_builder.append_message(tool_call_summary.tool_call_result)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_non_tool_calling_llm(
prompt_builder.user_message_and_token_cnt[0],
self.name,
*tool_responses,
)
)
)
return prompt_builder

View File

@@ -9,13 +9,9 @@ from sqlalchemy.orm import Session
from danswer.db.models import Persona
from danswer.db.models import Tool as ToolDBModel
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.utils.logger import setup_logger

View File

@@ -11,34 +11,24 @@ from pydantic import BaseModel
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.base_tool import BaseTool
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import openapi_to_url
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
TOOL_ARG_SYSTEM_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
TOOL_ARG_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from danswer.tools.tool_implementations.custom.openapi_parsing import (
openapi_to_method_specs,
)
from danswer.tools.tool_implementations.custom.openapi_parsing import openapi_to_url
from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.tool_implementations.custom.openapi_parsing import (
validate_openapi_schema,
)
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
@@ -53,7 +43,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: ToolResultType
class CustomTool(BaseTool):
class CustomTool(Tool):
def __init__(
self,
method_spec: MethodSpec,

View File

@@ -0,0 +1,21 @@
from typing import cast
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.models import ToolResponse
def build_user_message_for_custom_tool_for_non_tool_calling_llm(
query: str,
tool_name: str,
*args: ToolResponse,
) -> str:
tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result
return f"""
Here's the result from the {tool_name} tool:
{tool_run_summary}
Now respond to the following:
{query}
""".strip()

Some files were not shown because too many files have changed in this diff Show More