mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00

Compare commits: bg_process...danswer_au
1 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 856c2debd9 |  |
76  .github/workflows/nightly-scan-licenses.yml (vendored)
@@ -1,76 +0,0 @@
# Scan for problematic software licenses

# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389

name: 'Nightly - Scan licenses'
on:
  # schedule:
  #   - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
  workflow_dispatch: # Allows manual triggering

permissions:
  actions: read
  contents: read
  security-events: write

jobs:
  scan-licenses:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
          cache: 'pip'
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt

      - name: Get explicit and transitive dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
          pip freeze > requirements-all.txt

      - name: Check python
        id: license_check_report
        uses: pilosus/action-pip-license-checker@v2
        with:
          requirements: 'requirements-all.txt'
          fail: 'Copyleft'
          exclude: '(?i)^(pylint|aio[-_]*).*'

      - name: Print report
        if: ${{ always() }}
        run: echo "${{ steps.license_check_report.outputs.report }}"

      - name: Install npm dependencies
        working-directory: ./web
        run: npm ci

      - name: Run Trivy vulnerability scanner in repo mode
        uses: aquasecurity/trivy-action@0.28.0
        with:
          scan-type: fs
          scanners: license
          format: table
          # format: sarif
          # output: trivy-results.sarif
          severity: HIGH,CRITICAL

      # - name: Upload Trivy scan results to GitHub Security tab
      #   uses: github/codeql-action/upload-sarif@v3
      #   with:
      #     sarif_file: trivy-results.sarif
@@ -18,9 +18,6 @@ env:
  # Jira
  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
  # Google
  GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
  GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}

jobs:
  connectors-check:
2  .github/workflows/pr-python-model-tests.yml (vendored)
@@ -15,7 +15,7 @@ env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

jobs:
  model-check:
  connectors-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
17  README.md
@@ -1,5 +1,4 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>

<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

## ⭐Star History

[](https://star-history.com/#danswer-ai/danswer&Date)

## ✨Contributors

<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>
@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
    DANSWER_RUNNING_IN_DOCKER="true"

ARG CA_CERT_CONTENT=""

RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,6 +39,15 @@ RUN apt-get update && \
    apt-get clean


# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
    echo "Adding custom CA certificate"; \
    echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
    chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
    update-ca-certificates; \
    else \
    echo "No custom CA certificate provided"; \
    fi

# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -77,6 +87,7 @@ RUN apt-get update && \
RUN python -c "from tokenizers import Tokenizer; \
    Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"


# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
    nltk.download('stopwords', quiet=True); \
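For illustration only (not part of this diff): one way the new `CA_CERT_CONTENT` build argument could be supplied at image build time. The certificate path, image tag, and build context below are placeholders, not values from the repository.

```python
# A minimal sketch: feed a local CA certificate into the image build via the
# CA_CERT_CONTENT build arg added above. File name, tag, and context are made up.
import subprocess
from pathlib import Path

ca_cert = Path("my-ca.crt").read_text()  # hypothetical local certificate file

subprocess.run(
    [
        "docker", "build",
        "--build-arg", f"CA_CERT_CONTENT={ca_cert}",
        "-t", "danswer-backend:custom-ca",  # hypothetical image tag
        "backend",                          # assumed build context directory
    ],
    check=True,
)
```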
@@ -1,50 +0,0 @@
"""single tool call per message

Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Step 1: Delete extraneous ToolCall entries
    # Keep only the ToolCall with the smallest 'id' for each 'message_id'
    op.execute(
        sa.text(
            """
            DELETE FROM tool_call
            WHERE id NOT IN (
                SELECT MIN(id)
                FROM tool_call
                WHERE message_id IS NOT NULL
                GROUP BY message_id
            );
            """
        )
    )

    # Step 2: Add a unique constraint on message_id
    op.create_unique_constraint(
        constraint_name="uq_tool_call_message_id",
        table_name="tool_call",
        columns=["message_id"],
    )


def downgrade() -> None:
    # Step 1: Drop the unique constraint on message_id
    op.drop_constraint(
        constraint_name="uq_tool_call_message_id",
        table_name="tool_call",
        type_="unique",
    )
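For illustration (not from this commit): a minimal pre-flight check for the deduplication step above, run against the same `tool_call` table. The connection URL is a placeholder.

```python
# Sketch only: list messages that still have more than one ToolCall row, i.e.
# the rows the migration's Step 1 would trim down to a single entry.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:pass@localhost/danswer")  # placeholder URL
with engine.connect() as conn:
    rows = conn.execute(
        text(
            """
            SELECT message_id, COUNT(*) AS tool_calls
            FROM tool_call
            WHERE message_id IS NOT NULL
            GROUP BY message_id
            HAVING COUNT(*) > 1
            """
        )
    ).all()

for message_id, tool_calls in rows:
    print(f"message {message_id} has {tool_calls} tool calls; only the lowest id would be kept")
```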
@@ -1,70 +0,0 @@
"""nullable search settings for historic index attempts

Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing foreign key constraint
    op.drop_constraint(
        "fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
    )

    # Modify the column to be nullable
    op.alter_column(
        "index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
    )

    # Add back the foreign key with ON DELETE SET NULL
    op.create_foreign_key(
        "fk_index_attempt_search_settings",
        "index_attempt",
        "search_settings",
        ["search_settings_id"],
        ["id"],
        ondelete="SET NULL",
    )


def downgrade() -> None:
    # Warning: This will delete all index attempts that don't have search settings
    op.execute(
        """
        DELETE FROM index_attempt
        WHERE search_settings_id IS NULL
        """
    )

    # Drop foreign key constraint
    op.drop_constraint(
        "fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
    )

    # Modify the column to be not nullable
    op.alter_column(
        "index_attempt",
        "search_settings_id",
        existing_type=sa.INTEGER(),
        nullable=False,
    )

    # Add back the foreign key without ON DELETE SET NULL
    op.create_foreign_key(
        "fk_index_attempt_search_settings",
        "index_attempt",
        "search_settings",
        ["search_settings_id"],
        ["id"],
    )
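For illustration (not the application's actual models): what the upgraded foreign key amounts to at the ORM level. After this migration, deleting a search settings row leaves historic `index_attempt` rows in place with `search_settings_id` set to NULL. Class and column names below are assumptions for the sketch.

```python
# Minimal declarative sketch of the column shape this migration produces.
from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class IndexAttempt(Base):  # hypothetical, trimmed-down model
    __tablename__ = "index_attempt"

    id = Column(Integer, primary_key=True)
    search_settings_id = Column(
        Integer,
        ForeignKey("search_settings.id", ondelete="SET NULL"),
        nullable=True,  # made nullable by this migration
    )
```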
@@ -93,9 +93,9 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()
@@ -510,23 +510,19 @@ cookie_transport = CookieTransport(

# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
    async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
    async def write_token(self, user: User) -> str:
        tenant_id = get_tenant_id_for_email(user.email)
        data = {
            "sub": str(user.id),
            "aud": self.token_audience,
            "tenant_id": tenant_id,
        }
        return data

    async def write_token(self, user: User) -> str:
        data = await self._create_token_data(user)
        return generate_jwt(
            data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
        )


def get_jwt_strategy() -> TenantAwareJWTStrategy:
def get_jwt_strategy() -> JWTStrategy:
    return TenantAwareJWTStrategy(
        secret=USER_AUTH_SECRET,
        lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
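For illustration (not repository code): a self-contained PyJWT sketch of the kind of claims a tenant-aware token built this way carries. The secret, audience, and tenant id are placeholders.

```python
# Sketch only: encode and decode a token with a tenant_id claim alongside the
# usual sub/aud claims. Values are made up for the example.
import jwt

SECRET = "example-secret"
AUDIENCE = "fastapi-users:auth"  # assumed audience value

token = jwt.encode(
    {"sub": "user-uuid", "aud": AUDIENCE, "tenant_id": "tenant_123"},
    SECRET,
    algorithm="HS256",
)
claims = jwt.decode(token, SECRET, algorithms=["HS256"], audience=AUDIENCE)
print(claims["tenant_id"])  # -> "tenant_123"
```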
@@ -14,19 +14,18 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
@@ -109,27 +108,29 @@ def on_task_postrun(
    if task_id.startswith(RedisDocumentSet.PREFIX):
        document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
        if document_set_id is not None:
            rds = RedisDocumentSet(tenant_id, int(document_set_id))
            rds = RedisDocumentSet(int(document_set_id))
            r.srem(rds.taskset_key, task_id)
        return

    if task_id.startswith(RedisUserGroup.PREFIX):
        usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
        if usergroup_id is not None:
            rug = RedisUserGroup(tenant_id, int(usergroup_id))
            rug = RedisUserGroup(int(usergroup_id))
            r.srem(rug.taskset_key, task_id)
        return

    if task_id.startswith(RedisConnectorDelete.PREFIX):
        cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
    if task_id.startswith(RedisConnectorDeletion.PREFIX):
        cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
            rcd = RedisConnectorDeletion(int(cc_pair_id))
            r.srem(rcd.taskset_key, task_id)
        return

    if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
        cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
    if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
        cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
            rcp = RedisConnectorPruning(int(cc_pair_id))
            r.srem(rcp.taskset_key, task_id)
        return
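For reference (not part of the diff): the prefix/id/uuid task-id convention these branches rely on, mirrored as a standalone function. The same logic appears in `RedisObjectHelper.get_id_from_task_id` further down in this commit.

```python
# Standalone illustration of the "<prefix>_<object id>_<uuid>" task id scheme
# used by the postrun handler above to find the owning taskset.
def get_id_from_task_id(task_id: str) -> str | None:
    parts = task_id.split("_")
    if len(parts) != 3:
        return None
    return parts[1]


assert get_id_from_task_id("documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc") == "1"
assert get_id_from_task_id("not-a-prefixed-task-id") is None
```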
@@ -172,30 +173,44 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("Running as a secondary celery worker.")

    # Exit early if multi-tenant since primary worker check not needed
    if MULTI_TENANT:
        return

    # Set up variables for waiting on primary worker
    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60
    r = get_redis_client(tenant_id=None)

    logger.info("Running as a secondary celery worker.")
    logger.info("Waiting for all tenant primary workers to be ready...")
    time_start = time.monotonic()

    logger.info("Waiting for primary worker to be ready...")
    while True:
        if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
        tenant_ids = get_all_tenant_ids()
        # Check if we have a primary worker lock for each tenant
        all_tenants_ready = all(
            get_redis_client(tenant_id=tenant_id).exists(
                DanswerRedisLocks.PRIMARY_WORKER
            )
            for tenant_id in tenant_ids
        )

        if all_tenants_ready:
            break

        time_elapsed = time.monotonic() - time_start
        logger.info(
            f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        ready_tenants = sum(
            1
            for tenant_id in tenant_ids
            if get_redis_client(tenant_id=tenant_id).exists(
                DanswerRedisLocks.PRIMARY_WORKER
            )
        )

        logger.info(
            f"Not all tenant primary workers are ready yet. "
            f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
            f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )

        if time_elapsed > WAIT_LIMIT:
            msg = (
                f"Primary worker was not ready within the timeout. "
                f"Not all tenant primary workers were ready within the timeout "
                f"({WAIT_LIMIT} seconds). Exiting..."
            )
            logger.error(msg)
@@ -203,7 +218,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:

        time.sleep(WAIT_INTERVAL)

    logger.info("Wait for primary worker completed successfully. Continuing...")
    logger.info("All tenant primary workers are ready. Continuing...")
    return
@@ -215,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    if not celery_is_worker_primary(sender):
        return

    if not sender.primary_worker_lock:
    if not hasattr(sender, "primary_worker_locks"):
        return

    logger.info("Releasing primary worker lock.")
    lock = sender.primary_worker_lock
    try:
        if lock.owned():
            try:
                lock.release()
                sender.primary_worker_lock = None
            except Exception as e:
                logger.error(f"Failed to release primary worker lock: {e}")
    except Exception as e:
        logger.error(f"Failed to check if primary worker lock is owned: {e}")
    for tenant_id, lock in sender.primary_worker_locks.items():
        try:
            if lock and lock.owned():
                logger.debug(f"Attempting to release lock for tenant {tenant_id}")
                try:
                    lock.release()
                    logger.debug(f"Successfully released lock for tenant {tenant_id}")
                except Exception as e:
                    logger.error(
                        f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
                    )
                finally:
                    sender.primary_worker_locks[tenant_id] = None
        except Exception as e:
            logger.error(
                f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
            )


def on_setup_logging(
@@ -3,144 +3,26 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler  # type: ignore
from celery.signals import beat_init

import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger(__name__)
logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")


class DynamicTenantScheduler(PersistentScheduler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        logger.info("Initializing DynamicTenantScheduler")
        super().__init__(*args, **kwargs)
        self._reload_interval = timedelta(minutes=2)
        self._last_reload = self.app.now() - self._reload_interval
        # Let the parent class handle store initialization
        self.setup_schedule()
        self._update_tenant_tasks()
        logger.info(f"Set reload interval to {self._reload_interval}")

    def setup_schedule(self) -> None:
        logger.info("Setting up initial schedule")
        super().setup_schedule()
        logger.info("Initial schedule setup complete")

    def tick(self) -> float:
        retval = super().tick()
        now = self.app.now()
        if (
            self._last_reload is None
            or (now - self._last_reload) > self._reload_interval
        ):
            logger.info("Reload interval reached, initiating tenant task update")
            self._update_tenant_tasks()
            self._last_reload = now
            logger.info("Tenant task update completed, reset reload timer")
        return retval

    def _update_tenant_tasks(self) -> None:
        logger.info("Starting tenant task update process")
        try:
            logger.info("Fetching all tenant IDs")
            tenant_ids = get_all_tenant_ids()
            logger.info(f"Found {len(tenant_ids)} tenants")

            logger.info("Fetching tasks to schedule")
            tasks_to_schedule = fetch_versioned_implementation(
                "danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
            )

            new_beat_schedule: dict[str, dict[str, Any]] = {}

            current_schedule = self.schedule.items()

            existing_tenants = set()
            for task_name, _ in current_schedule:
                if "-" in task_name:
                    existing_tenants.add(task_name.split("-")[-1])
            logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")

            for tenant_id in tenant_ids:
                if tenant_id not in existing_tenants:
                    logger.info(f"Processing new tenant: {tenant_id}")

                for task in tasks_to_schedule():
                    task_name = f"{task['name']}-{tenant_id}"
                    logger.debug(f"Creating task configuration for {task_name}")
                    new_task = {
                        "task": task["task"],
                        "schedule": task["schedule"],
                        "kwargs": {"tenant_id": tenant_id},
                    }
                    if options := task.get("options"):
                        logger.debug(f"Adding options to task {task_name}: {options}")
                        new_task["options"] = options
                    new_beat_schedule[task_name] = new_task

            if self._should_update_schedule(current_schedule, new_beat_schedule):
                logger.info(
                    "Schedule update required",
                    extra={
                        "new_tasks": len(new_beat_schedule),
                        "current_tasks": len(current_schedule),
                    },
                )

                # Create schedule entries
                entries = {}
                for name, entry in new_beat_schedule.items():
                    entries[name] = self.Entry(
                        name=name,
                        app=self.app,
                        task=entry["task"],
                        schedule=entry["schedule"],
                        options=entry.get("options", {}),
                        kwargs=entry.get("kwargs", {}),
                    )

                # Update the schedule using the scheduler's methods
                self.schedule.clear()
                self.schedule.update(entries)

                # Ensure changes are persisted
                self.sync()

                logger.info("Schedule update completed successfully")
            else:
                logger.info("Schedule is up to date, no changes needed")

        except (AttributeError, KeyError) as e:
            logger.exception(f"Failed to process task configuration: {str(e)}")
        except Exception as e:
            logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")

    def _should_update_schedule(
        self, current_schedule: dict, new_schedule: dict
    ) -> bool:
        """Compare schedules to determine if an update is needed."""
        logger.debug("Comparing current and new schedules")
        current_tasks = set(name for name, _ in current_schedule)
        new_tasks = set(new_schedule.keys())
        needs_update = current_tasks != new_tasks
        logger.debug(f"Schedule update needed: {needs_update}")
        return needs_update


@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
    logger.info("beat_init signal received.")

    # Celery beat shouldn't touch the db at all. But just setting a low minimum here.
    # celery beat shouldn't touch the db at all. But just setting a low minimum here.
    SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
    SqlEngine.init_engine(pool_size=2, max_overflow=0)
    app_base.wait_for_redis(sender, **kwargs)
@@ -153,4 +35,68 @@ def on_setup_logging(
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


celery_app.conf.beat_scheduler = DynamicTenantScheduler
#####
# Celery Beat (Periodic Tasks) Settings
#####

tenant_ids = get_all_tenant_ids()

tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-connector-deletion",
        "task": "check_for_connector_deletion_task",
        "schedule": timedelta(seconds=60),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-prune",
        "task": "check_for_pruning",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "kombu-message-cleanup",
        "task": "kombu_message_cleanup_task",
        "schedule": timedelta(seconds=3600),
        "options": {"priority": DanswerCeleryPriority.LOWEST},
    },
    {
        "name": "monitor-vespa-sync",
        "task": "monitor_vespa_sync",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
]


# Build the celery beat schedule dynamically
beat_schedule = {}

for tenant_id in tenant_ids:
    for task in tasks_to_schedule:
        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
        beat_schedule[task_name] = {
            "task": task["task"],
            "schedule": task["schedule"],
            "options": task["options"],
            "kwargs": {"tenant_id": tenant_id},  # Must pass tenant_id as an argument
        }

# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)

# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
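For illustration (not repository code): a standalone sketch of the schedule keys the per-tenant loop above produces. The tenant ids and the single task entry are placeholders.

```python
# Each beat entry name is suffixed with the tenant id, and the tenant id is
# also passed through kwargs so the task knows which tenant to operate on.
from datetime import timedelta

tenant_ids = ["tenant_a", "tenant_b"]  # placeholder tenant ids
tasks_to_schedule = [
    {"name": "check-for-vespa-sync", "task": "check_for_vespa_sync_task", "schedule": timedelta(seconds=5)},
]

beat_schedule = {
    f"{task['name']}-{tenant_id}": {
        "task": task["task"],
        "schedule": task["schedule"],
        "kwargs": {"tenant_id": tenant_id},
    }
    for tenant_id in tenant_ids
    for task in tasks_to_schedule
}
print(sorted(beat_schedule))
# ['check-for-vespa-sync-tenant_a', 'check-for-vespa-sync-tenant_b']
```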
@@ -85,6 +85,5 @@ celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.shared",
        "danswer.background.celery.tasks.vespa",
        "danswer.background.celery.tasks.connector_deletion",
    ]
)
@@ -13,15 +13,21 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()
@@ -73,45 +79,91 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("Running as the primary celery worker.")

    if MULTI_TENANT:
        return
    sender.primary_worker_locks = {}

    # This is singleton work that should be done on startup exactly once
    # by the primary worker. This is unnecessary in the multi tenant scenario
    r = get_redis_client(tenant_id=None)
    # by the primary worker
    tenant_ids = get_all_tenant_ids()
    for tenant_id in tenant_ids:
        r = get_redis_client(tenant_id=tenant_id)

    # For the moment, we're assuming that we are the only primary worker
    # that should be running.
    # TODO: maybe check for or clean up another zombie primary worker if we detect it
    r.delete(DanswerRedisLocks.PRIMARY_WORKER)
        # For the moment, we're assuming that we are the only primary worker
        # that should be running.
        # TODO: maybe check for or clean up another zombie primary worker if we detect it
        r.delete(DanswerRedisLocks.PRIMARY_WORKER)

    # this process wide lock is taken to help other workers start up in order.
    # it is planned to use this lock to enforce singleton behavior on the primary
    # worker, since the primary worker does redis cleanup on startup, but this isn't
    # implemented yet.
    lock = r.lock(
        DanswerRedisLocks.PRIMARY_WORKER,
        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
    )
        # this process wide lock is taken to help other workers start up in order.
        # it is planned to use this lock to enforce singleton behavior on the primary
        # worker, since the primary worker does redis cleanup on startup, but this isn't
        # implemented yet.
        lock = r.lock(
            DanswerRedisLocks.PRIMARY_WORKER,
            timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
        )

    logger.info("Primary worker lock: Acquire starting.")
    acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
    if acquired:
        logger.info("Primary worker lock: Acquire succeeded.")
    else:
        logger.error("Primary worker lock: Acquire failed!")
        raise WorkerShutdown("Primary worker lock could not be acquired!")
        logger.info("Primary worker lock: Acquire starting.")
        acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
        if acquired:
            logger.info("Primary worker lock: Acquire succeeded.")
        else:
            logger.error("Primary worker lock: Acquire failed!")
            raise WorkerShutdown("Primary worker lock could not be acquired!")

    # tacking on our own user data to the sender
    sender.primary_worker_lock = lock
        # tacking on our own user data to the sender
        sender.primary_worker_locks[tenant_id] = lock

    # As currently designed, when this worker starts as "primary", we reinitialize redis
    # to a clean state (for our purposes, anyway)
    r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
    r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
        # As currently designed, when this worker starts as "primary", we reinitialize redis
        # to a clean state (for our purposes, anyway)
        r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
        r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

    r.delete(RedisConnectorCredentialPair.get_taskset_key())
    r.delete(RedisConnectorCredentialPair.get_fence_key())
        r.delete(RedisConnectorCredentialPair.get_taskset_key())
        r.delete(RedisConnectorCredentialPair.get_fence_key())

        for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
            r.delete(key)


@worker_ready.connect
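For illustration (not repository code): the per-tenant startup locking pattern above reduced to a minimal redis-py sketch. The key name, timeout, and tenant ids are placeholders, and a locally reachable Redis is assumed.

```python
# Sketch only: one lock per tenant, deleted first on the assumption that any
# previous primary worker is gone, then re-acquired and kept on the worker.
import redis

LOCK_KEY = "da_lock:primary_worker"  # placeholder for DanswerRedisLocks.PRIMARY_WORKER
LOCK_TIMEOUT = 120  # seconds; placeholder for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT

primary_worker_locks = {}
for tenant_id in ["tenant_a", "tenant_b"]:  # placeholder tenant ids
    r = redis.Redis()  # the real code gets a per-tenant client
    r.delete(LOCK_KEY)
    lock = r.lock(LOCK_KEY, timeout=LOCK_TIMEOUT)
    if not lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2):
        raise RuntimeError(f"Could not acquire primary worker lock for {tenant_id}")
    primary_worker_locks[tenant_id] = lock
```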
@@ -164,36 +216,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
            if not celery_is_worker_primary(worker):
                return

            if not hasattr(worker, "primary_worker_lock"):
            if not hasattr(worker, "primary_worker_locks"):
                return

            lock = worker.primary_worker_lock
            # Retrieve all tenant IDs
            tenant_ids = get_all_tenant_ids()

            r = get_redis_client(tenant_id=None)
            for tenant_id in tenant_ids:
                lock = worker.primary_worker_locks.get(tenant_id)
                if not lock:
                    continue  # Skip if no lock for this tenant

            if lock.owned():
                task_logger.debug("Reacquiring primary worker lock.")
                lock.reacquire()
            else:
                task_logger.warning(
                    "Full acquisition of primary worker lock. "
                    "Reasons could be worker restart or lock expiration."
                )
                lock = r.lock(
                    DanswerRedisLocks.PRIMARY_WORKER,
                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
                )
                r = get_redis_client(tenant_id=tenant_id)

                task_logger.info("Primary worker lock: Acquire starting.")
                acquired = lock.acquire(
                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
                )
                if acquired:
                    task_logger.info("Primary worker lock: Acquire succeeded.")
                    worker.primary_worker_lock = lock
                if lock.owned():
                    task_logger.debug(
                        f"Reacquiring primary worker lock for tenant {tenant_id}."
                    )
                    lock.reacquire()
                else:
                    task_logger.error("Primary worker lock: Acquire failed!")
                    raise TimeoutError("Primary worker lock could not be acquired!")
                    task_logger.warning(
                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
                        "Reasons could be worker restart or lock expiration."
                    )
                    lock = r.lock(
                        DanswerRedisLocks.PRIMARY_WORKER,
                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
                    )

                    task_logger.info(
                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
                    )
                    acquired = lock.acquire(
                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
                    )
                    if acquired:
                        task_logger.info(
                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
                        )
                        worker.primary_worker_locks[tenant_id] = lock
                    else:
                        task_logger.error(
                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
                        )
                        raise TimeoutError(
                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
                        )

        except Exception:
            task_logger.exception("Periodic task failed.")
@@ -1,96 +0,0 @@
from datetime import timedelta
from typing import Any

from celery.beat import PersistentScheduler  # type: ignore
from celery.utils.log import get_task_logger

from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation

logger = get_task_logger(__name__)


class DynamicTenantScheduler(PersistentScheduler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._reload_interval = timedelta(minutes=1)
        self._last_reload = self.app.now() - self._reload_interval

    def setup_schedule(self) -> None:
        super().setup_schedule()

    def tick(self) -> float:
        retval = super().tick()
        now = self.app.now()
        if (
            self._last_reload is None
            or (now - self._last_reload) > self._reload_interval
        ):
            logger.info("Reloading schedule to check for new tenants...")
            self._update_tenant_tasks()
            self._last_reload = now
        return retval

    def _update_tenant_tasks(self) -> None:
        logger.info("Checking for tenant task updates...")
        try:
            tenant_ids = get_all_tenant_ids()
            tasks_to_schedule = fetch_versioned_implementation(
                "danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
            )

            new_beat_schedule: dict[str, dict[str, Any]] = {}

            current_schedule = getattr(self, "_store", {"entries": {}}).get(
                "entries", {}
            )

            existing_tenants = set()
            for task_name in current_schedule.keys():
                if "-" in task_name:
                    existing_tenants.add(task_name.split("-")[-1])

            for tenant_id in tenant_ids:
                if tenant_id not in existing_tenants:
                    logger.info(f"Found new tenant: {tenant_id}")

                for task in tasks_to_schedule():
                    task_name = f"{task['name']}-{tenant_id}"
                    new_task = {
                        "task": task["task"],
                        "schedule": task["schedule"],
                        "kwargs": {"tenant_id": tenant_id},
                    }
                    if options := task.get("options"):
                        new_task["options"] = options
                    new_beat_schedule[task_name] = new_task

            if self._should_update_schedule(current_schedule, new_beat_schedule):
                logger.info(
                    "Updating schedule",
                    extra={
                        "new_tasks": len(new_beat_schedule),
                        "current_tasks": len(current_schedule),
                    },
                )
                if not hasattr(self, "_store"):
                    self._store: dict[str, dict] = {"entries": {}}
                self.update_from_dict(new_beat_schedule)
                logger.info(f"New schedule: {new_beat_schedule}")

                logger.info("Tenant tasks updated successfully")
            else:
                logger.debug("No schedule updates needed")

        except (AttributeError, KeyError):
            logger.exception("Failed to process task configuration")
        except Exception:
            logger.exception("Unexpected error updating tenant tasks")

    def _should_update_schedule(
        self, current_schedule: dict, new_schedule: dict
    ) -> bool:
        """Compare schedules to determine if an update is needed."""
        current_tasks = set(current_schedule.keys())
        new_tasks = set(new_schedule.keys())
        return current_tasks != new_tasks
@@ -1,10 +1,568 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: str):
|
||||
self._id: str = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[2]
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
return 0
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(int(self._id))
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class is used to scan documents by cc_pair in the db and collect them into
|
||||
a unified set for syncing.
|
||||
|
||||
It differs from the other redis helpers in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorPruning(RedisObjectHelper):
|
||||
"""Celery will kick off a long running generator task to crawl the connector and
|
||||
find any missing docs, which will each then get a new cleanup task. The progress of
|
||||
those tasks will then be monitored to completion.
|
||||
|
||||
Example rough happy path order:
|
||||
Check connectorpruning_fence_1
|
||||
Send generator task with id connectorpruning+generator_1_{uuid}
|
||||
|
||||
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
|
||||
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
|
||||
in taskset connectorpruning_taskset_1
|
||||
on completion, generator sets connectorpruning_generator_complete_1
|
||||
|
||||
celery postrun removes subtasks from taskset
|
||||
monitor beat task cleans up when taskset reaches 0 items
|
||||
"""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
self.documents_to_prune: set[str] = set()
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in self.documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorIndexing(RedisObjectHelper):
    """Celery will kick off a long running indexing task to crawl the connector and
    find any new or updated docs, which will each then get a new sync task or be
    indexed inline.

    ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
    e.g. "2/5"
    """

    PREFIX = "connectorindexing"
    FENCE_PREFIX = PREFIX + "_fence"  # a fence for the entire indexing process
    GENERATOR_TASK_PREFIX = PREFIX + "+generator"

    TASKSET_PREFIX = PREFIX + "_taskset"  # stores a list of prune tasks id's
    SUBTASK_PREFIX = PREFIX + "+sub"

    GENERATOR_LOCK_PREFIX = "da_lock:indexing"
    GENERATOR_PROGRESS_PREFIX = (
        PREFIX + "_generator_progress"
    )  # a signal that contains generator progress
    GENERATOR_COMPLETE_PREFIX = (
        PREFIX + "_generator_complete"
    )  # a signal that the generator has finished

    def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
        super().__init__(f"{cc_pair_id}/{search_settings_id}")

    @property
    def generator_lock_key(self) -> str:
        return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"

    @property
    def generator_task_id_prefix(self) -> str:
        return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"

    @property
    def generator_progress_key(self) -> str:
        # example: connectorpruning_generator_progress_1
        return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"

    @property
    def generator_complete_key(self) -> str:
        # example: connectorpruning_generator_complete_1
        return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"

    @property
    def subtask_id_prefix(self) -> str:
        return f"{self.SUBTASK_PREFIX}_{self._id}"

    def generate_tasks(
        self,
        celery_app: Celery,
        db_session: Session,
        redis_client: Redis,
        lock: redis.lock.Lock | None,
        tenant_id: str | None,
    ) -> int | None:
        return None

    def is_indexing(self, redis_client: Redis) -> bool:
        """A single example of a helper method being refactored into the redis helper"""
        if redis_client.exists(self.fence_key):
            return True

        return False


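# --- Editor's illustrative sketch (not part of the diff above) ---
# A worked example of the key scheme defined by RedisConnectorIndexing above, using
# the "2/5" composite id from its docstring. The values are derived by hand from
# PREFIX and the f-string properties, assuming RedisObjectHelper.__init__ simply
# stores the id string; they are not captured from a running deployment.
_example_index_helper = RedisConnectorIndexing(cc_pair_id=2, search_settings_id=5)
assert _example_index_helper.generator_lock_key == "da_lock:indexing_2/5"
assert _example_index_helper.generator_progress_key == "connectorindexing_generator_progress_2/5"
assert _example_index_helper.generator_complete_key == "connectorindexing_generator_complete_2/5"
assert _example_index_helper.subtask_id_prefix == "connectorindexing+sub_2/5"
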
class RedisConnectorStop(RedisObjectHelper):
    """Used to signal any running tasks for a connector to stop. We should refactor
    connector related redis helpers into a single class.
    """

    PREFIX = "connectorstop"
    FENCE_PREFIX = PREFIX + "_fence"  # a fence for the entire indexing process
    TASKSET_PREFIX = PREFIX + "_taskset"  # stores a list of prune tasks id's

    def __init__(self, id: int) -> None:
        super().__init__(str(id))

    def generate_tasks(
        self,
        celery_app: Celery,
        db_session: Session,
        redis_client: Redis,
        lock: redis.lock.Lock | None,
        tenant_id: str | None,
    ) -> int | None:
        return None


def celery_get_queue_length(queue: str, r: Redis) -> int:
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
@@ -17,7 +18,7 @@ from danswer.connectors.models import Document
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -40,14 +41,14 @@ def _get_deletion_status(
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,48 +0,0 @@
from datetime import timedelta
from typing import Any

from danswer.configs.constants import DanswerCeleryPriority


tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-connector-deletion",
        "task": "check_for_connector_deletion_task",
        "schedule": timedelta(seconds=20),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-prune",
        "task": "check_for_pruning",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "kombu-message-cleanup",
        "task": "kombu_message_cleanup_task",
        "schedule": timedelta(seconds=3600),
        "options": {"priority": DanswerCeleryPriority.LOWEST},
    },
    {
        "name": "monitor-vespa-sync",
        "task": "monitor_vespa_sync",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
]


def get_tasks_to_schedule() -> list[dict[str, Any]]:
    return tasks_to_schedule
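# --- Editor's illustrative sketch (not part of the diff above) ---
# The deleted module above exposed its schedule via get_tasks_to_schedule(). A minimal
# sketch of how such a list is typically wired into Celery beat; the celery_app name
# and this registration step are assumptions for illustration, not taken from the diff.
from celery import Celery

celery_app = Celery("danswer")  # assumed app instance, illustration only
celery_app.conf.beat_schedule = {
    entry["name"]: {
        "task": entry["task"],
        "schedule": entry["schedule"],
        "options": entry["options"],
    }
    for entry in get_tasks_to_schedule()
}
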
@@ -10,6 +10,13 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -18,8 +25,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
@@ -57,7 +62,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -66,10 +71,10 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
redis_connector.stop.set_fence(True)
|
||||
r.set(rcs.fence_key, cc_pair_id)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
r.delete(rcs.fence_key)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
@@ -101,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if redis_connector.delete.fenced:
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to load the state of the object inside the fence
|
||||
@@ -118,49 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
fence_payload = RedisConnectorDeletionFenceData(
|
||||
fence_value = RedisConnectorDeletionFenceData(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
try:
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
|
||||
if r.get(rci.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
if r.get(rcp.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
redis_connector.delete.taskset_clear()
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.delete.generate_tasks(
|
||||
app, db_session, lock_beat
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.fence_key)
|
||||
raise
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.fence_key)
|
||||
return None
|
||||
else:
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
@@ -175,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
fence_value.num_tasks = tasks_generated
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -2,6 +2,8 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
@@ -12,6 +14,12 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
@@ -42,8 +50,6 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -99,22 +105,19 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if current_search_settings.provider_type is None and not MULTI_TENANT:
|
||||
if old_search_settings:
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
# only warm up if search settings were changed
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -123,7 +126,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
@@ -136,10 +138,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
rci = RedisConnectorIndexing(
|
||||
cc_pair_id, search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
if r.exists(rci.fence_key):
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -302,15 +304,15 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
if r.exists(rci.fence_key):
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
@@ -318,17 +320,19 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.taskset_key)
|
||||
|
||||
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexingFenceData(
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_fence(payload)
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
@@ -341,8 +345,6 @@ def try_creating_indexing_task(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
@@ -359,12 +361,11 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
fence_value.index_attempt_id = index_attempt_id
|
||||
fence_value.celery_task_id = result.id
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
except Exception:
|
||||
redis_connector_index.set_fence(payload)
|
||||
r.delete(rci.fence_key)
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
@@ -387,12 +388,7 @@ def connector_indexing_proxy_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing proxy - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
@@ -406,56 +402,29 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
f"Indexing proxy - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
break
|
||||
job.release()
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -477,78 +446,78 @@ def connector_indexing_task(
|
||||
|
||||
Returns None if the task did not run (possibly due to a conflict).
|
||||
Otherwise, returns an int >= 0 representing the number of indexed docs.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt = None
|
||||
n_final_progress: int | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
n_final_progress = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
if r.exists(rcd.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
f"fence={rcd.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
if r.exists(rcs.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
f"fence={rcs.fence_key}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
while True:
|
||||
# wait for the fence to come up
|
||||
if not redis_connector_index.fenced:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
|
||||
)
|
||||
raise
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
rci.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
task_logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
fence_data.started = datetime.now(timezone.utc)
|
||||
r.set(rci.fence_key, fence_data.model_dump_json())
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -576,19 +545,11 @@ def connector_indexing_task(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
rcs.fence_key, rci.generator_progress_key, lock, r
|
||||
)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
@@ -600,29 +561,27 @@ def connector_indexing_task(
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
n_final_progress = int(cast(int, generator_progress_value))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
|
||||
if attempt:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return n_final_progress
|
||||
|
||||
@@ -11,6 +11,9 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
@@ -30,9 +33,7 @@ from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import pruning_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -145,11 +146,8 @@ def try_creating_prune_generator_task(
|
||||
is used to trigger prunes immediately, e.g. via the web ui.
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
count = redis_connector.prune.get_active_task_count()
|
||||
if count > 0:
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
return None
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
@@ -166,10 +164,15 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
if redis_connector.prune.fenced: # skip pruning if already pruning
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# skip pruning if already pruning
|
||||
if r.exists(rcp.fence_key):
|
||||
return None
|
||||
|
||||
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
|
||||
# skip pruning if the cc_pair is deleting
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
@@ -177,10 +180,10 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
@@ -196,7 +199,7 @@ def try_creating_prune_generator_task(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
r.set(rcp.fence_key, 1)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
return None
|
||||
@@ -226,17 +229,12 @@ def connector_pruning_generator_task(
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
pruning_ctx_dict["request_id"] = self.request.id
|
||||
pruning_ctx.set(pruning_ctx_dict)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -269,11 +267,10 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
rcs.fence_key, rcp.generator_progress_key, lock, r
|
||||
)
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
@@ -300,29 +297,31 @@ def connector_pruning_generator_task(
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
self.app, db_session, r, None, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
|
||||
@@ -0,0 +1,8 @@
from datetime import datetime

from pydantic import BaseModel


class RedisConnectorDeletionFenceData(BaseModel):
    num_tasks: int | None
    submitted: datetime
@@ -0,0 +1,10 @@
from datetime import datetime

from pydantic import BaseModel


class RedisConnectorIndexingFenceData(BaseModel):
    index_attempt_id: int | None
    started: datetime | None
    submitted: datetime
    celery_task_id: str | None
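# --- Editor's illustrative sketch (not part of the diff above) ---
# Elsewhere in this diff these fence payloads are written with
# r.set(fence_key, payload.model_dump_json()) and read back with
# model_validate_json(). A minimal round-trip sketch of that pattern; the Redis
# connection details and the key name are assumptions for illustration only.
from datetime import datetime, timezone

from redis import Redis

r = Redis()  # assumed local connection
fence_key = "connectorindexing_fence_2/5"  # assumed key, matching the prefix scheme

payload = RedisConnectorIndexingFenceData(
    index_attempt_id=None,
    started=None,
    submitted=datetime.now(timezone.utc),
    celery_task_id=None,
)
r.set(fence_key, payload.model_dump_json())

raw = r.get(fence_key)
if raw is not None:
    restored = RedisConnectorIndexingFenceData.model_validate_json(raw)
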
@@ -19,6 +19,18 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -55,14 +67,7 @@ from danswer.db.models import IndexAttempt
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
@@ -187,7 +192,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -223,10 +228,10 @@ def try_generate_document_set_sync_tasks(
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if rds.fenced:
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
@@ -264,7 +269,7 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -278,9 +283,10 @@ def try_generate_user_group_sync_tasks(
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(tenant_id, usergroup_id)
|
||||
if rug.fenced:
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
@@ -320,7 +326,7 @@ def try_generate_user_group_sync_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -346,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -356,12 +362,16 @@ def monitor_document_set_taskset(
|
||||
|
||||
document_set_id = int(document_set_id_str)
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
if not rds.fenced:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial_count = rds.payload
|
||||
if initial_count is None:
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
@@ -389,38 +399,48 @@ def monitor_document_set_taskset(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
key_bytes: bytes, r: Redis, tenant_id: str | None
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rcd.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
# the fence is setting up but isn't ready yet
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -504,15 +524,15 @@ def monitor_connector_deletion_taskset(
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.taskset_clear()
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
@@ -521,37 +541,46 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcp.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
generator_value = r.get(rcp.generator_complete_key)
|
||||
if generator_value is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
try:
|
||||
initial_count = int(cast(int, generator_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcp.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
@@ -566,37 +595,53 @@ def monitor_ccpair_indexing_taskset(
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
|
||||
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
progress_count = int(cast(int, generator_progress_value))
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress_count} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
|
||||
)
|
||||
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None:
|
||||
generator_complete_value = r.get(rci.generator_complete_key)
|
||||
if generator_complete_value is None:
|
||||
if result_state in READY_STATES:
|
||||
# IF the task state is READY, THEN generator_complete should be set
|
||||
# if it isn't, then the worker crashed
|
||||
@@ -607,7 +652,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
@@ -615,10 +660,22 @@ def monitor_ccpair_indexing_taskset(
|
||||
failure_reason="Connector indexing aborted or exceptioned.",
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
try:
|
||||
status_value = int(cast(int, generator_complete_value))
|
||||
status_enum = HTTPStatus(status_value)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"monitor_ccpair_indexing_taskset: "
|
||||
f"generator_complete_value=f{generator_complete_value} could not be parsed."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
|
||||
@@ -627,7 +684,11 @@ def monitor_ccpair_indexing_taskset(
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
@@ -639,7 +700,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
Returns True if the task actually did work, False
|
||||
"""
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -690,12 +751,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
for a in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
rci = RedisConnectorIndexing(
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
)
|
||||
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
|
||||
if not r.exists(
|
||||
RedisConnectorIndex.fence_key_with_ids(
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
)
|
||||
):
|
||||
if not r.exists(rci.fence_key):
|
||||
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
@@ -703,15 +763,15 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
@@ -722,19 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = celery_app
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.beat", "celery_app"
|
||||
)
|
||||
|
||||
@@ -118,13 +118,7 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
|
||||
@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
|
||||
@@ -18,7 +18,6 @@ from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
@@ -78,49 +77,31 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
@@ -279,7 +260,6 @@ ChatPacket = (
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
)
ChatPacketStream = Iterator[ChatPacket]

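Since `ChatPacketStream` is simply `Iterator[ChatPacket]`, callers drain it and branch on the concrete packet type. A rough sketch of that consumption pattern, using only the classes named in the union above (the handling itself is illustrative, not the project's actual frontend logic):

def consume_stream(packets: ChatPacketStream) -> None:
    # Iterate the streamed packets and dispatch on their concrete type.
    for packet in packets:
        if isinstance(packet, StreamStopInfo):
            break  # the backend has signalled the end of the stream
        elif isinstance(packet, (MessageSpecificCitations, MessageResponseIDInfo)):
            print(f"metadata packet: {packet}")
        else:
            print(f"content packet: {packet}")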
@@ -552,13 +532,6 @@ def stream_chat_message_objects(
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
@@ -577,16 +550,13 @@ def stream_chat_message_objects(
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_style_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -656,11 +626,7 @@ def stream_chat_message_objects(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=bing_api_key,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
@@ -701,7 +667,13 @@ def stream_chat_message_objects(
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=answer_style_config,
|
||||
answer_style_config=AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
),
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
@@ -805,8 +777,7 @@ def stream_chat_message_objects(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
@@ -836,7 +807,6 @@ def stream_chat_message_objects(
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
logger.debug("Post-LLM answer processing")
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
if reference_db_search_docs:
|
||||
message_specific_citations = _translate_citations(
|
||||
@@ -864,15 +834,17 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else None
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -9,19 +9,19 @@ prompts:
|
||||
system: >
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide
|
||||
grounded, accurate, and concise answers to diverse queries.
|
||||
|
||||
|
||||
You always clearly communicate ANY UNCERTAINTY in your answer.
|
||||
# Task Prompt (as shown in UI)
|
||||
task: >
|
||||
Answer my query based on the documents provided.
|
||||
The documents may not all be relevant, ignore any documents that are not directly relevant
|
||||
to the most recent user query.
|
||||
|
||||
|
||||
I have not read or seen any of the documents and do not want to read them.
|
||||
|
||||
|
||||
If there are no relevant documents, refer to the chat history and your internal knowledge.
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
|
||||
@@ -30,21 +30,21 @@ prompts:
# Prompts the LLM to include citations in the form [1], [2], etc.,
# which get parsed to match the passed-in sources
include_citations: true


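The `include_citations` flag above depends on the model emitting bracketed markers such as [1] or [2] that are later matched back to the supplied sources. A minimal, self-contained sketch of that matching step (the regex and the `sources` list shape are assumptions for illustration, not the project's actual parser):

import re


def map_citations(answer_text: str, sources: list[str]) -> dict[int, str]:
    """Map bracketed citation numbers in the answer to the provided sources."""
    cited: dict[int, str] = {}
    for match in re.finditer(r"\[(\d+)\]", answer_text):
        number = int(match.group(1))
        if 1 <= number <= len(sources):
            cited[number] = sources[number - 1]
    return cited


print(map_citations("See [1] and [2] for details.", ["doc-a", "doc-b", "doc-c"]))
# {1: 'doc-a', 2: 'doc-b'}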
- name: "ImageGeneration"
|
||||
description: "Generates images from user descriptions!"
|
||||
description: "Generates images based on user prompts!"
|
||||
system: >
|
||||
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
|
||||
|
||||
For appropriate requests, you will generate an image that matches the user's requirements.
|
||||
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
|
||||
|
||||
You aim to be helpful while maintaining appropriate content standards.
|
||||
You are an advanced image generation system capable of creating diverse and detailed images.
|
||||
|
||||
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
|
||||
|
||||
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
|
||||
task: >
|
||||
Based on the user's description, create a high-quality image that accurately reflects their request.
|
||||
Pay close attention to the specified details, styles, and desired elements.
|
||||
|
||||
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
include_citations: false
|
||||
|
||||
@@ -64,13 +64,14 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: "Summarize relevant information from retrieved context!"
|
||||
system: >
|
||||
You are a text summarizing assistant that highlights the most important knowledge from the
|
||||
context provided, prioritizing the information that relates to the user query.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You ARE NOT creative and always stick to the provided documents.
|
||||
If there are no documents, refer to the conversation history.
|
||||
|
||||
@@ -83,6 +84,7 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: "Recites information from retrieved context! Least creative but most safe!"
|
||||
system: >
|
||||
@@ -90,10 +92,10 @@ prompts:
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
You only provide quotes that are EXACT substrings from provided documents!
|
||||
|
||||
|
||||
If there are no documents provided,
|
||||
simply tell the user that there are no documents to reference.
|
||||
|
||||
|
||||
You NEVER generate new text or phrases outside of the citation.
|
||||
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
|
||||
task: >
|
||||
|
||||
@@ -251,6 +251,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
|
||||
# for some connectors
|
||||
ENABLE_EXPENSIVE_EXPERT_CALLS = False
|
||||
|
||||
GOOGLE_DRIVE_INCLUDE_SHARED = False
|
||||
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
|
||||
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
|
||||
|
||||
# TODO these should be available for frontend configuration, via advanced options expandable
|
||||
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
|
||||
@@ -478,7 +481,3 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
@@ -125,7 +125,6 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

@@ -17,7 +17,6 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -250,11 +249,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
|
||||
@@ -23,16 +23,7 @@ def datetime_to_utc(dt: datetime) -> datetime:


def time_str_to_utc(datetime_str: str) -> datetime:
try:
dt = parse(datetime_str)
except ValueError:
# Handle malformed timezone by attempting to fix common format issues
if "0000" in datetime_str:
# Convert "0000" to "+0000" for proper timezone parsing
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
dt = parse(fixed_dt_str)
else:
raise
dt = parse(datetime_str)
return datetime_to_utc(dt)


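The removed branch above worked around timestamps whose UTC offset arrives as a bare " 0000" instead of " +0000". As a standalone illustration of that fallback (assuming `parse` is dateutil's parser, as in the surrounding code):

from datetime import datetime

from dateutil.parser import parse


def parse_time_str(datetime_str: str) -> datetime:
    """Parse a timestamp string, tolerating a malformed ' 0000' UTC offset."""
    try:
        return parse(datetime_str)
    except ValueError:
        if " 0000" in datetime_str:
            # Repair the offset so it parses as ' +0000'
            return parse(datetime_str.replace(" 0000", " +0000"))
        raise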
@@ -16,7 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
from danswer.connectors.gmail.connector import GmailConnector
|
||||
@@ -100,7 +99,6 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


@@ -1,239 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_FRESHDESK_ID_PREFIX = "FRESHDESK_"
|
||||
|
||||
|
||||
_TICKET_FIELDS_TO_INCLUDE = {
|
||||
"fr_escalated",
|
||||
"spam",
|
||||
"priority",
|
||||
"source",
|
||||
"status",
|
||||
"type",
|
||||
"is_escalated",
|
||||
"tags",
|
||||
"nr_due_by",
|
||||
"nr_escalated",
|
||||
"cc_emails",
|
||||
"fwd_emails",
|
||||
"reply_cc_emails",
|
||||
"ticket_cc_emails",
|
||||
"support_email",
|
||||
"to_emails",
|
||||
}
|
||||
|
||||
_SOURCE_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
1: "Email",
|
||||
2: "Portal",
|
||||
3: "Phone",
|
||||
7: "Chat",
|
||||
9: "Feedback Widget",
|
||||
10: "Outbound Email",
|
||||
}
|
||||
|
||||
_PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
1: "low",
|
||||
2: "medium",
|
||||
3: "high",
|
||||
4: "urgent",
|
||||
}
|
||||
|
||||
_STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
2: "open",
|
||||
3: "pending",
|
||||
4: "resolved",
|
||||
5: "closed",
|
||||
}
|
||||
|
||||
|
||||
def _create_metadata_from_ticket(ticket: dict) -> dict:
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
# Combine all emails into a list so there are no repeated emails
|
||||
email_data: set[str] = set()
|
||||
|
||||
for key, value in ticket.items():
|
||||
# Skip fields that aren't useful for embedding
|
||||
if key not in _TICKET_FIELDS_TO_INCLUDE:
|
||||
continue
|
||||
|
||||
# Skip empty fields
|
||||
if not value or value == "[]":
|
||||
continue
|
||||
|
||||
# Convert strings or lists to strings
|
||||
stringified_value: str | list[str]
|
||||
if isinstance(value, list):
|
||||
stringified_value = [str(item) for item in value]
|
||||
else:
|
||||
stringified_value = str(value)
|
||||
|
||||
if "email" in key:
|
||||
if isinstance(stringified_value, list):
|
||||
email_data.update(stringified_value)
|
||||
else:
|
||||
email_data.add(stringified_value)
|
||||
else:
|
||||
metadata[key] = stringified_value
|
||||
|
||||
if email_data:
|
||||
metadata["emails"] = list(email_data)
|
||||
|
||||
# Convert source numbers to human-parsable string
|
||||
if source_number := ticket.get("source"):
|
||||
metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get(
|
||||
source_number, "Unknown Source Type"
|
||||
)
|
||||
|
||||
# Convert priority numbers to human-parsable string
|
||||
if priority_number := ticket.get("priority"):
|
||||
metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get(
|
||||
priority_number, "Unknown Priority"
|
||||
)
|
||||
|
||||
# Convert status to human-parsable string
|
||||
if status_number := ticket.get("status"):
|
||||
metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get(
|
||||
status_number, "Unknown Status"
|
||||
)
|
||||
|
||||
due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00"))
|
||||
metadata["overdue"] = str(datetime.now(timezone.utc) > due_by)
|
||||
|
||||
return metadata
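To make the number-to-label maps above concrete, this is roughly what the metadata step yields for a small ticket payload; the sample values are made up and the dicts are trimmed copies of the tables defined earlier in this file:

_SOURCE = {1: "Email", 2: "Portal", 3: "Phone", 7: "Chat"}
_PRIORITY = {1: "low", 2: "medium", 3: "high", 4: "urgent"}
_STATUS = {2: "open", 3: "pending", 4: "resolved", 5: "closed"}

sample_ticket = {"source": 1, "priority": 4, "status": 2, "tags": ["billing"]}

metadata = {
    "source": _SOURCE.get(sample_ticket["source"], "Unknown Source Type"),
    "priority": _PRIORITY.get(sample_ticket["priority"], "Unknown Priority"),
    "status": _STATUS.get(sample_ticket["status"], "Unknown Status"),
    "tags": [str(tag) for tag in sample_ticket["tags"]],
}
print(metadata)
# {'source': 'Email', 'priority': 'urgent', 'status': 'open', 'tags': ['billing']}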
|
||||
|
||||
|
||||
def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
|
||||
# Use the ticket description as the text
|
||||
text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}"
|
||||
metadata = _create_metadata_from_ticket(ticket)
|
||||
|
||||
# This is also used in the ID because it is more unique than just the ticket ID
|
||||
link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}"
|
||||
|
||||
return Document(
|
||||
id=_FRESHDESK_ID_PREFIX + link,
|
||||
sections=[
|
||||
Section(
|
||||
link=link,
|
||||
text=text,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.FRESHDESK,
|
||||
semantic_identifier=ticket["subject"],
|
||||
metadata=metadata,
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
ticket["updated_at"].replace("Z", "+00:00")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FreshdeskConnector(PollConnector, LoadConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def load_credentials(self, credentials: dict[str, str | int]) -> None:
|
||||
api_key = credentials.get("freshdesk_api_key")
|
||||
domain = credentials.get("freshdesk_domain")
|
||||
password = credentials.get("freshdesk_password")
|
||||
|
||||
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
|
||||
raise ConnectorMissingCredentialError(
|
||||
"All Freshdesk credentials must be strings"
|
||||
)
|
||||
|
||||
self.api_key = str(api_key)
|
||||
self.domain = str(domain)
|
||||
self.password = str(password)
|
||||
|
||||
def _fetch_tickets(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> Iterator[List[dict]]:
|
||||
"""
|
||||
'end' is not currently used, so we may double fetch tickets created after the indexing
|
||||
starts but before the actual call is made.
|
||||
|
||||
To use 'end' would require us to use the search endpoint but it has limitations,
|
||||
namely having to fetch all IDs and then individually fetch each ticket because there is no
|
||||
'include' field available for this endpoint:
|
||||
https://developers.freshdesk.com/api/#filter_tickets
|
||||
"""
|
||||
if self.api_key is None or self.domain is None or self.password is None:
|
||||
raise ConnectorMissingCredentialError("freshdesk")
|
||||
|
||||
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
|
||||
params: dict[str, int | str] = {
|
||||
"include": "description",
|
||||
"per_page": 50,
|
||||
"page": 1,
|
||||
}
|
||||
|
||||
if start:
|
||||
params["updated_since"] = start.isoformat()
|
||||
|
||||
while True:
|
||||
response = requests.get(
|
||||
base_url, auth=(self.api_key, self.password), params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code == 204:
|
||||
break
|
||||
|
||||
tickets = json.loads(response.content)
|
||||
logger.info(
|
||||
f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})"
|
||||
)
|
||||
|
||||
yield tickets
|
||||
|
||||
if len(tickets) < int(params["per_page"]):
|
||||
break
|
||||
|
||||
params["page"] = int(params["page"]) + 1
|
||||
|
||||
def _process_tickets(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for ticket_batch in self._fetch_tickets(start, end):
|
||||
for ticket in ticket_batch:
|
||||
doc_batch.append(_create_doc_from_ticket(ticket, self.domain))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._process_tickets()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
yield from self._process_tickets(start_datetime, end_datetime)
|
||||
@@ -1,8 +1,4 @@
|
||||
import re
|
||||
import time
|
||||
from base64 import urlsafe_b64decode
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
@@ -10,7 +6,6 @@ from typing import Dict
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -39,64 +34,6 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _execute_with_retry(request: Any) -> Any:
|
||||
max_attempts = 10
|
||||
attempt = 0
|
||||
|
||||
while attempt < max_attempts:
|
||||
# Note: for reasons unknown, the Google API will sometimes return a 429
|
||||
# and even after waiting the retry period, it will return another 429.
|
||||
# It could be due to a few possibilities:
|
||||
# 1. Other things are also requesting from the Gmail API with the same key
|
||||
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
|
||||
# 3. The retry-after has a maximum and we've already hit the limit for the day
|
||||
# or it's something else...
|
||||
try:
|
||||
return request.execute()
|
||||
except HttpError as error:
|
||||
attempt += 1
|
||||
|
||||
if error.resp.status == 429:
|
||||
# Attempt to get 'Retry-After' from headers
|
||||
retry_after = error.resp.get("Retry-After")
|
||||
if retry_after:
|
||||
sleep_time = int(retry_after)
|
||||
else:
|
||||
# Extract 'Retry after' timestamp from error message
|
||||
match = re.search(
|
||||
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
|
||||
str(error),
|
||||
)
|
||||
if match:
|
||||
retry_after_timestamp = match.group(1)
|
||||
retry_after_dt = datetime.strptime(
|
||||
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
).replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
sleep_time = max(
|
||||
int((retry_after_dt - current_time).total_seconds()),
|
||||
0,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"No Retry-After header or timestamp found in error message: {error}"
|
||||
)
|
||||
sleep_time = 60
|
||||
|
||||
sleep_time += 3 # Add a buffer to be safe
|
||||
|
||||
logger.info(
|
||||
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
else:
|
||||
raise
|
||||
|
||||
# If we've exhausted all attempts
|
||||
raise Exception(f"Failed to execute request after {max_attempts} attempts")
|
||||
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
@@ -219,7 +156,7 @@ class GmailConnector(LoadConnector, PollConnector):
|
||||
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
|
||||
service = discovery.build("gmail", "v1", credentials=self.creds)
|
||||
while page_token is not None:
|
||||
result = _execute_with_retry(
|
||||
result = (
|
||||
service.users()
|
||||
.messages()
|
||||
.list(
|
||||
@@ -228,17 +165,18 @@ class GmailConnector(LoadConnector, PollConnector):
|
||||
q=query,
|
||||
maxResults=self.batch_size,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
page_token = result.get("nextPageToken")
|
||||
messages = result.get("messages", [])
|
||||
doc_batch = []
|
||||
for message in messages:
|
||||
message_id = message["id"]
|
||||
msg = _execute_with_retry(
|
||||
msg = (
|
||||
service.users()
|
||||
.messages()
|
||||
.get(userId="me", id=message_id, format="full")
|
||||
.execute()
|
||||
)
|
||||
doc = self._email_to_document(msg)
|
||||
doc_batch.append(doc)
|
||||
|
||||
@@ -1,305 +1,556 @@
|
||||
import io
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR
|
||||
from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS
|
||||
from danswer.connectors.google_drive.constants import SCOPE_DOC_URL
|
||||
from danswer.connectors.google_drive.constants import SLIM_BATCH_SIZE
|
||||
from danswer.connectors.google_drive.constants import USER_FIELDS
|
||||
from danswer.connectors.google_drive.doc_conversion import (
|
||||
convert_drive_item_to_document,
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
|
||||
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
|
||||
if not string:
|
||||
return []
|
||||
return [s.strip() for s in string.split(",") if s.strip()]
|
||||
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
|
||||
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
|
||||
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
|
||||
|
||||
|
||||
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
|
||||
return [url.split("/")[-1] for url in urls]
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat this here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
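Given the comment above, `add_retries` is used throughout this file by wrapping a Google API call in a lambda and then invoking the returned callable, e.g. `add_retries(lambda: request.execute())()`. A toy stand-in that shows only that call shape (the real `retry_builder` also applies exponential backoff up to `max_delay`):

from typing import Any, Callable


def toy_retry_builder(tries: int, max_delay: int) -> Callable[[Callable[[], Any]], Callable[[], Any]]:
    """Toy stand-in for retry_builder; max_delay is ignored in this sketch."""

    def decorator(func: Callable[[], Any]) -> Callable[[], Any]:
        def wrapped() -> Any:
            last_exc: Exception | None = None
            for _ in range(max(tries, 1)):
                try:
                    return func()
                except Exception as exc:  # the real helper narrows this to retryable errors
                    last_exc = exc
            assert last_exc is not None
            raise last_exc

        return wrapped

    return decorator


add_retries_demo = toy_retry_builder(tries=3, max_delay=30)
print(add_retries_demo(lambda: "ok")())  # prints "ok"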
|
||||
|
||||
|
||||
def _run_drive_file_query(
|
||||
service: discovery.Resource,
|
||||
query: str,
|
||||
continue_on_failure: bool,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
next_page_token = ""
|
||||
while next_page_token is not None:
|
||||
logger.debug(f"Running Google Drive fetch with query: {query}")
|
||||
results = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.list(
|
||||
corpora="allDrives"
|
||||
if include_shared
|
||||
else "user", # needed to search through shared drives
|
||||
pageSize=batch_size,
|
||||
supportsAllDrives=include_shared,
|
||||
includeItemsFromAllDrives=include_shared,
|
||||
fields=(
|
||||
"nextPageToken, files(mimeType, id, name, permissions, "
|
||||
"modifiedTime, webViewLink, shortcutDetails)"
|
||||
),
|
||||
pageToken=next_page_token,
|
||||
q=query,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
next_page_token = results.get("nextPageToken")
|
||||
files = results["files"]
|
||||
for file in files:
|
||||
if follow_shortcuts and "shortcutDetails" in file:
|
||||
try:
|
||||
file_shortcut_points_to = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file["shortcutDetails"]["targetId"],
|
||||
supportsAllDrives=include_shared,
|
||||
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
yield file_shortcut_points_to
|
||||
except HttpError:
|
||||
logger.error(
|
||||
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
|
||||
)
|
||||
if continue_on_failure:
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
yield file
|
||||
|
||||
|
||||
def _get_folder_id(
|
||||
service: discovery.Resource,
|
||||
parent_id: str,
|
||||
folder_name: str,
|
||||
include_shared: bool,
|
||||
follow_shortcuts: bool,
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the ID of a folder given its name and the ID of its parent folder.
|
||||
"""
|
||||
query = f"'{parent_id}' in parents and name='{folder_name}' and "
|
||||
if follow_shortcuts:
|
||||
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
|
||||
else:
|
||||
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
|
||||
|
||||
# TODO: support specifying folder path in shared drive rather than just `My Drive`
|
||||
results = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.list(
|
||||
q=query,
|
||||
spaces="drive",
|
||||
fields="nextPageToken, files(id, name, shortcutDetails)",
|
||||
supportsAllDrives=include_shared,
|
||||
includeItemsFromAllDrives=include_shared,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
items = results.get("files", [])
|
||||
|
||||
folder_id = None
|
||||
if items:
|
||||
if follow_shortcuts and "shortcutDetails" in items[0]:
|
||||
folder_id = items[0]["shortcutDetails"]["targetId"]
|
||||
else:
|
||||
folder_id = items[0]["id"]
|
||||
return folder_id
|
||||
|
||||
|
||||
def _get_folders(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
|
||||
if follow_shortcuts:
|
||||
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
|
||||
|
||||
if folder_id:
|
||||
query += f"and '{folder_id}' in parents "
|
||||
query = query.rstrip() # remove the trailing space(s)
|
||||
|
||||
for file in _run_drive_file_query(
|
||||
service=service,
|
||||
query=query,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
# Need to check this since file may have been a target of a shortcut
|
||||
# and not necessarily a folder
|
||||
if file["mimeType"] == DRIVE_FOLDER_TYPE:
|
||||
yield file
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def _get_files(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
|
||||
if time_range_start is not None:
|
||||
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
|
||||
query += f"and modifiedTime >= '{time_start}' "
|
||||
if time_range_end is not None:
|
||||
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
|
||||
query += f"and modifiedTime <= '{time_stop}' "
|
||||
if folder_id:
|
||||
query += f"and '{folder_id}' in parents "
|
||||
query = query.rstrip() # remove the trailing space(s)
|
||||
|
||||
files = _run_drive_file_query(
|
||||
service=service,
|
||||
query=query,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def get_all_files_batched(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
# if True, will fetch files in sub-folders of the specified folder ID.
|
||||
# Only applies if folder_id is specified.
|
||||
traverse_subfolders: bool = True,
|
||||
folder_ids_traversed: list[str] | None = None,
|
||||
) -> Iterator[list[GoogleDriveFileType]]:
|
||||
"""Gets all files matching the criteria specified by the args from Google Drive
|
||||
in batches of size `batch_size`.
|
||||
"""
|
||||
found_files = _get_files(
|
||||
service=service,
|
||||
continue_on_failure=continue_on_failure,
|
||||
time_range_start=time_range_start,
|
||||
time_range_end=time_range_end,
|
||||
folder_id=folder_id,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
yield from batch_generator(
|
||||
items=found_files,
|
||||
batch_size=batch_size,
|
||||
pre_batch_yield=lambda batch_files: logger.debug(
|
||||
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
|
||||
),
|
||||
)
|
||||
|
||||
if traverse_subfolders and folder_id is not None:
|
||||
folder_ids_traversed = folder_ids_traversed or []
|
||||
subfolders = _get_folders(
|
||||
service=service,
|
||||
folder_id=folder_id,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
for subfolder in subfolders:
|
||||
if subfolder["id"] not in folder_ids_traversed:
|
||||
logger.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
folder_ids_traversed.append(subfolder["id"])
|
||||
yield from get_all_files_batched(
|
||||
service=service,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
time_range_start=time_range_start,
|
||||
time_range_end=time_range_end,
|
||||
folder_id=subfolder["id"],
|
||||
traverse_subfolders=traverse_subfolders,
|
||||
folder_ids_traversed=folder_ids_traversed,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Skipping subfolder since already traversed: " + subfolder["name"]
|
||||
)
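A quick sketch of how a caller might drive `get_all_files_batched` defined above; the function and parameter names come from this file, while `count_drive_files` and the idea of counting are illustrative and require real credentials to run:

from googleapiclient import discovery  # already a dependency of this module


def count_drive_files(service: discovery.Resource, folder_id: str | None = None) -> int:
    """Count files visible to the connector by draining the batched generator."""
    total = 0
    for batch in get_all_files_batched(
        service=service,
        continue_on_failure=True,
        folder_id=folder_id,
        traverse_subfolders=True,
    ):
        total += len(batch)
    return total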
|
||||
|
||||
|
||||
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
mime_type = file["mimeType"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, so fetching them this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
)
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
include_shared_drives: bool = True,
|
||||
shared_drive_urls: str | None = None,
|
||||
include_my_drives: bool = True,
|
||||
my_drive_emails: str | None = None,
|
||||
shared_folder_urls: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# OLD PARAMETERS
|
||||
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
|
||||
# if specified, will only index files in these folders
|
||||
folder_paths: list[str] | None = None,
|
||||
include_shared: bool | None = None,
|
||||
follow_shortcuts: bool | None = None,
|
||||
only_org_public: bool | None = None,
|
||||
continue_on_failure: bool | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
# Check for old input parameters
|
||||
if (
|
||||
folder_paths is not None
|
||||
or include_shared is not None
|
||||
or follow_shortcuts is not None
|
||||
or only_org_public is not None
|
||||
or continue_on_failure is not None
|
||||
):
|
||||
logger.exception(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
|
||||
if (
|
||||
not include_shared_drives
|
||||
and not include_my_drives
|
||||
and not shared_folder_urls
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of include_shared_drives, include_my_drives,"
|
||||
" or shared_folder_urls must be true"
|
||||
)
|
||||
|
||||
self.folder_paths = folder_paths or []
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.include_shared_drives = include_shared_drives
|
||||
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
|
||||
self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list)
|
||||
|
||||
self.include_my_drives = include_my_drives
|
||||
self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails)
|
||||
|
||||
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
|
||||
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)
|
||||
|
||||
self.primary_admin_email: str | None = None
|
||||
self.google_domain: str | None = None
|
||||
|
||||
self.include_shared = include_shared
|
||||
self.follow_shortcuts = follow_shortcuts
|
||||
self.only_org_public = only_org_public
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
self._TRAVERSED_PARENT_IDS: set[str] = set()
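For reference, constructing the rewritten connector with its new comma-separated URL parameters might look like the following; the URLs and emails are placeholders, and `load_credentials` (shown below) must still be called before indexing:

connector = GoogleDriveConnector(
    include_shared_drives=True,
    shared_drive_urls="https://drive.google.com/drive/folders/<DRIVE_ID_1>,https://drive.google.com/drive/folders/<DRIVE_ID_2>",
    include_my_drives=True,
    my_drive_emails="alice@example.com, bob@example.com",
    shared_folder_urls=None,
)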
|
||||
@staticmethod
|
||||
def _process_folder_paths(
|
||||
service: discovery.Resource,
|
||||
folder_paths: list[str],
|
||||
include_shared: bool,
|
||||
follow_shortcuts: bool,
|
||||
) -> list[str]:
|
||||
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
|
||||
folder_ids: list[str] = []
|
||||
for path in folder_paths:
|
||||
folder_names = path.split("/")
|
||||
parent_id = "root"
|
||||
for folder_name in folder_names:
|
||||
found_parent_id = _get_folder_id(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
folder_name=folder_name,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
)
|
||||
if found_parent_id is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"Folder '{folder_name}' in path '{path}' "
|
||||
"not found in Google Drive"
|
||||
)
|
||||
)
|
||||
parent_id = found_parent_id
|
||||
folder_ids.append(parent_id)
|
||||
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._TRAVERSED_PARENT_IDS.add(folder_id)
|
||||
return folder_ids
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self.google_domain = primary_admin_email.split("@")[1]
|
||||
self.primary_admin_email = primary_admin_email
|
||||
|
||||
self.creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self.creds = creds
|
||||
return new_creds_dict
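To make the docstring above concrete, the credentials dict handed to `load_credentials` carries either an OAuth token blob or a service account key, plus the primary admin email. The key names below are the constants visible in this diff; the values are placeholders, and the exact shape differs slightly between the old and new connector:

from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_PRIMARY_ADMIN_KEY
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY

# OAuth-style credentials (token acquired through the user consent flow)
oauth_credentials = {
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
    DB_CREDENTIALS_DICT_TOKEN_KEY: "<json-encoded OAuth token>",
}

# Service-account-style credentials (key JSON used to impersonate workspace users)
service_account_credentials = {
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
    DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: "<json-encoded service account key>",
}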
|
||||
|
||||
def get_google_resource(
|
||||
self,
|
||||
service_name: str = "drive",
|
||||
service_version: str = "v3",
|
||||
user_email: str | None = None,
|
||||
) -> Resource:
|
||||
if isinstance(self.creds, ServiceAccountCredentials):
|
||||
creds = self.creds.with_subject(user_email or self.primary_admin_email)
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
elif isinstance(self.creds, OAuthCredentials):
|
||||
service = build(service_name, service_version, credentials=self.creds)
|
||||
else:
|
||||
raise PermissionError("No credentials found")
|
||||
|
||||
return service
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = self.get_google_resource("admin", "directory_v1")
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
def _fetch_drive_items(
|
||||
self,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
primary_drive_service = self.get_google_resource()
|
||||
|
||||
if self.include_shared_drives:
|
||||
shared_drive_urls = self.shared_drive_ids
|
||||
if not shared_drive_urls:
|
||||
# if no parent ids are specified, get all shared drives using the admin account
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=primary_drive_service.drives().list,
|
||||
list_key="drives",
|
||||
useDomainAdminAccess=True,
|
||||
fields="drives(id)",
|
||||
):
|
||||
shared_drive_urls.append(drive["id"])
|
||||
|
||||
# For each shared drive, retrieve all files
|
||||
for shared_drive_id in shared_drive_urls:
|
||||
for file in get_files_in_shared_drive(
|
||||
service=primary_drive_service,
|
||||
drive_id=shared_drive_id,
|
||||
is_slim=is_slim,
|
||||
cache_folders=bool(self.shared_folder_ids),
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
yield file
|
||||
|
||||
if self.shared_folder_ids:
|
||||
# Crawl all the shared parent ids for files
|
||||
for folder_id in self.shared_folder_ids:
|
||||
yield from crawl_folders_for_files(
|
||||
service=primary_drive_service,
|
||||
parent_id=folder_id,
|
||||
personal_drive=False,
|
||||
traversed_parent_ids=self._TRAVERSED_PARENT_IDS,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
all_user_emails = []
|
||||
# get all personal docs from each user's personal drive
|
||||
if self.include_my_drives:
|
||||
if isinstance(self.creds, ServiceAccountCredentials):
|
||||
all_user_emails = self.my_drive_emails or []
|
||||
|
||||
# If using service account and no emails specified, fetch all users
|
||||
if not all_user_emails:
|
||||
all_user_emails = self._get_all_user_emails()
|
||||
|
||||
elif self.primary_admin_email:
|
||||
# If using OAuth, only fetch the primary admin email
|
||||
all_user_emails = [self.primary_admin_email]
|
||||
|
||||
for email in all_user_emails:
|
||||
logger.info(f"Fetching personal files for user: {email}")
|
||||
user_drive_service = self.get_google_resource(user_email=email)
|
||||
|
||||
yield from get_files_in_my_drive(
|
||||
service=user_drive_service,
|
||||
email=email,
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
def _fetch_docs_from_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
is_slim=False,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
user_email = file.get("owners", [{}])[0].get("emailAddress")
|
||||
service = self.get_google_resource(user_email=user_email)
|
||||
if doc := convert_drive_item_to_document(
|
||||
file=file,
|
||||
service=service,
|
||||
):
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
if self.creds is None:
|
||||
raise PermissionError("Not logged into Google Drive")
|
||||
|
||||
yield doc_batch
|
||||
service = discovery.build("drive", "v3", credentials=self.creds)
|
||||
folder_ids: Sequence[str | None] = self._process_folder_paths(
|
||||
service, self.folder_paths, self.include_shared, self.follow_shortcuts
|
||||
)
|
||||
if not folder_ids:
|
||||
folder_ids = [None]
|
||||
|
||||
file_batches = chain(
|
||||
*[
|
||||
get_all_files_batched(
|
||||
service=service,
|
||||
continue_on_failure=self.continue_on_failure,
|
||||
include_shared=self.include_shared,
|
||||
follow_shortcuts=self.follow_shortcuts,
|
||||
batch_size=self.batch_size,
|
||||
time_range_start=start,
|
||||
time_range_end=end,
|
||||
folder_id=folder_id,
|
||||
traverse_subfolders=True,
|
||||
)
|
||||
for folder_id in folder_ids
|
||||
]
|
||||
)
|
||||
for files_batch in file_batches:
|
||||
doc_batch = []
|
||||
for file in files_batch:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
continue
|
||||
|
||||
if self.only_org_public:
|
||||
if "permissions" not in file:
|
||||
continue
|
||||
if not any(
|
||||
permission["type"] == "domain"
|
||||
for permission in file["permissions"]
|
||||
):
|
||||
continue
|
||||
try:
|
||||
text_contents = extract_text(file, service) or ""
|
||||
except HttpError as e:
|
||||
reason = (
|
||||
e.error_details[0]["reason"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
message = (
|
||||
e.error_details[0]["message"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
|
||||
# these errors don't represent a failure in the connector, but simply files
|
||||
# that can't / shouldn't be indexed
|
||||
ERRORS_TO_CONTINUE_ON = [
|
||||
"cannotExportFile",
|
||||
"exportSizeLimitExceeded",
|
||||
"cannotDownloadFile",
|
||||
]
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
continue
|
||||
|
||||
raise
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=file["webViewLink"],
|
||||
sections=[
|
||||
Section(link=file["webViewLink"], text=text_contents)
|
||||
],
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Ran into exception when pulling a file from Google Drive"
|
||||
)
|
||||
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive()
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
yield from self._fetch_docs_from_drive()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
# Need to subtract 10 minutes from the start time to account for modifiedTime
# propagation: if a document is modified, it takes some time for the API to
# reflect the change. Without this offset, we may "miss" the update when polling.
yield from self._fetch_docs_from_drive(start, end)
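A literal reading of the comment above is to shift the poll window back by a fixed margin before querying; the ten-minute figure is taken from the comment and the helper itself is only illustrative:

_MODIFIED_TIME_PROPAGATION_BUFFER_SECONDS = 10 * 60  # value taken from the comment above


def widen_poll_window(start: float) -> float:
    """Move the poll start back so slow modifiedTime propagation does not drop updates."""
    return max(start - _MODIFIED_TIME_PROPAGATION_BUFFER_SECONDS, 0.0)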
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
is_slim=True,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
slim_batch.append(
|
||||
SlimDocument(
|
||||
id=file["webViewLink"],
|
||||
perm_sync_data={
|
||||
"doc_id": file.get("id"),
|
||||
"permissions": file.get("permissions", []),
|
||||
"permission_ids": file.get("permissionIds", []),
|
||||
"name": file.get("name"),
|
||||
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
|
||||
},
|
||||
)
|
||||
)
|
||||
if len(slim_batch) >= SLIM_BATCH_SIZE:
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
yield from self._extract_slim_docs_from_google_drive(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
import os
|
||||
|
||||
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
|
||||
if not service_account_json_path:
|
||||
raise ValueError(
|
||||
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
|
||||
)
|
||||
with open(service_account_json_path) as f:
|
||||
creds = json.load(f)
|
||||
|
||||
credentials_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
|
||||
}
|
||||
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
|
||||
if delegated_user:
|
||||
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
|
||||
|
||||
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
|
||||
connector.load_credentials(credentials_dict)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print(document_batch)
|
||||
break
|
||||
|
||||
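The poll_source comment above explains why the poll window start is pushed back before fetching. A minimal, hypothetical sketch of that offset idea (the 10-minute figure comes from the comment; the helper and constant names are invented):

# Hypothetical illustration only -- not part of the diff above.
DRIVE_MODIFIED_TIME_OFFSET_SECONDS = 10 * 60  # mirrors the "10 minutes" in the comment


def poll_with_offset(connector: "GoogleDriveConnector", start: float, end: float):
    # Shift the window start back so documents whose modifiedTime has not yet
    # propagated through the Drive API are still picked up by this poll.
    adjusted_start = max(0.0, start - DRIVE_MODIFIED_TIME_OFFSET_SECONDS)
    yield from connector.poll_source(adjusted_start, end)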
@@ -8,16 +8,24 @@ from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from googleapiclient.discovery import build # type: ignore
from sqlalchemy.orm import Session

from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
@@ -28,14 +36,15 @@ from danswer.utils.logger import setup_logger

logger = setup_logger()

GOOGLE_DRIVE_SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
]
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_primary_admin"

def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES

if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes


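As a rough sketch of how the combined scopes end up on a service-account credential (the helper below is illustrative, not the connector's own code; the real handling lives in get_google_drive_creds further down in this file):

# Illustrative helper; assumes the imports and build_gdrive_scopes() shown above.
import json as _json


def make_service_account_creds(
    service_account_key_json_str: str, delegated_user_email: str | None = None
) -> ServiceAccountCredentials:
    info = _json.loads(service_account_key_json_str)
    creds = ServiceAccountCredentials.from_service_account_info(
        info, scopes=build_gdrive_scopes()
    )
    # Optionally impersonate a workspace user, as the connector does when a
    # delegated user is configured.
    return creds.with_subject(delegated_user_email) if delegated_user_email else creds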
def _build_frontend_google_drive_redirect() -> str:
|
||||
@@ -43,7 +52,7 @@ def _build_frontend_google_drive_redirect() -> str:
|
||||
|
||||
|
||||
def get_google_drive_creds_for_authorized_user(
|
||||
token_json_str: str, scopes: list[str]
|
||||
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
|
||||
@@ -63,15 +72,21 @@ def get_google_drive_creds_for_authorized_user(
|
||||
return None
|
||||
|
||||
|
||||
def _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def get_google_drive_creds(
|
||||
credentials: dict[str, str], scopes: list[str] = GOOGLE_DRIVE_SCOPES
|
||||
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
@@ -85,27 +100,26 @@ def get_google_drive_creds(
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
|
||||
],
|
||||
}
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY]
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
|
||||
service_creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=scopes
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_creds = _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
if not service_creds.valid or not service_creds.expired:
|
||||
service_creds.refresh(Request())
|
||||
|
||||
if not service_creds.valid:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - service account credentials are invalid."
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
service_creds = (
|
||||
service_creds.with_subject(delegated_user_email)
|
||||
if service_creds
|
||||
else None
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
@@ -132,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_DRIVE_SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
@@ -155,34 +169,13 @@ def update_credential_access_tokens(
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=GOOGLE_DRIVE_SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
creds = flow.credentials
|
||||
token_json_str = creds.to_json()
|
||||
|
||||
# Get user email from Google API so we know who
|
||||
# the primary admin is for this connector
|
||||
try:
|
||||
admin_service = build("drive", "v3", credentials=creds)
|
||||
user_info = (
|
||||
admin_service.about()
|
||||
.get(
|
||||
fields="user(emailAddress)",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
email = user_info.get("user", {}).get("emailAddress")
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
}
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
return None
|
||||
@@ -191,15 +184,15 @@ def update_credential_access_tokens(
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
primary_admin_email: str | None = None,
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key()
|
||||
|
||||
credential_dict = {
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
}
|
||||
if primary_admin_email:
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
|
||||
if delegated_user_email:
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
|
||||
@@ -1,36 +1,7 @@
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"

FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
USER_FIELDS = "nextPageToken, users(primaryEmail)"

# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]

# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"

# Documentation and error messages
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Danswer without updating the Google Drive scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)

# Batch sizes
SLIM_BATCH_SIZE = 500
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]

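For context, the field strings above are meant to be passed straight through as the `fields` argument of the Drive v3 files().list call; a small, hypothetical usage example (assuming a googleapiclient `service` built elsewhere):

# Hypothetical usage of FILE_FIELDS; `service` is assumed to be a Drive v3
# Resource created with googleapiclient.discovery.build("drive", "v3", ...).
def list_first_page(service):
    response = (
        service.files()
        .list(
            q="trashed = false",
            fields=FILE_FIELDS,  # restricts the response to the fields listed above
            pageSize=100,
        )
        .execute()
    )
    return response.get("files", [])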
@@ -1,115 +0,0 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON
|
||||
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
from danswer.connectors.google_drive.models import GDriveMimeType
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_text(file: dict[str, str], service: Resource) -> str:
|
||||
mime_type = file["mimeType"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
)
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType, service: Resource
|
||||
) -> Document | None:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
return None
|
||||
try:
|
||||
text_contents = _extract_text(file, service) or ""
|
||||
except HttpError as e:
|
||||
reason = e.error_details[0]["reason"] if e.error_details else e.reason
|
||||
message = e.error_details[0]["message"] if e.error_details else e.reason
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
return None
|
||||
|
||||
raise
|
||||
|
||||
return Document(
|
||||
id=file["webViewLink"],
|
||||
sections=[Section(link=file["webViewLink"], text=text_contents)],
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
except Exception as e:
|
||||
if not CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when pulling a file from Google Drive")
|
||||
return None
|
||||
@@ -1,192 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from danswer.connectors.google_drive.constants import FILE_FIELDS
|
||||
from danswer.connectors.google_drive.constants import FOLDER_FIELDS
|
||||
from danswer.connectors.google_drive.constants import SLIM_FILE_FIELDS
|
||||
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
time_range_filter = ""
|
||||
if start is not None:
|
||||
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime >= '{time_start}'"
|
||||
if end is not None:
|
||||
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime <= '{time_stop}'"
|
||||
return time_range_filter
|
||||
|
||||
|
||||
def _get_folders_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str | None = None,
|
||||
personal_drive: bool = False,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# Follow shortcuts to folders
|
||||
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
|
||||
query += " and trashed = false"
|
||||
|
||||
if parent_id:
|
||||
query += f" and '{parent_id}' in parents"
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user" if personal_drive else "allDrives",
|
||||
supportsAllDrives=not personal_drive,
|
||||
includeItemsFromAllDrives=not personal_drive,
|
||||
fields=FOLDER_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
personal_drive: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
is_slim: bool = False,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user" if personal_drive else "allDrives",
|
||||
supportsAllDrives=not personal_drive,
|
||||
includeItemsFromAllDrives=not personal_drive,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def crawl_folders_for_files(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
personal_drive: bool,
|
||||
traversed_parent_ids: set[str],
|
||||
update_traversed_ids_func: Callable[[str], None],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
if parent_id in traversed_parent_ids:
|
||||
print(f"Skipping subfolder since already traversed: {parent_id}")
|
||||
return
|
||||
|
||||
update_traversed_ids_func(parent_id)
|
||||
|
||||
yield from _get_files_in_parent(
|
||||
service=service,
|
||||
personal_drive=personal_drive,
|
||||
start=start,
|
||||
end=end,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
personal_drive=personal_drive,
|
||||
):
|
||||
logger.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
yield from crawl_folders_for_files(
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
personal_drive=personal_drive,
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
update_traversed_ids_func=update_traversed_ids_func,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
is_slim: bool = False,
|
||||
cache_folders: bool = True,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
if cache_folders:
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
query += " and trashed = false"
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields="nextPageToken, files(id)",
|
||||
q=query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
|
||||
# Get all files in the shared drive
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
query += " and trashed = false"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def get_files_in_my_drive(
|
||||
service: Resource,
|
||||
email: str,
|
||||
is_slim: bool = False,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners"
|
||||
query += " and trashed = false"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user",
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
# Just in case we need to get the root folder id
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return service.files().get(fileId="root", fields="id").execute()["id"]
|
||||
@@ -1,35 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any

from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.utils.retry_wrapper import retry_builder


# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)


def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token

results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()

next_page_token = results.get("nextPageToken")
for item in results.get(list_key, []):
yield item
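A short usage sketch for execute_paginated_retrieval; the query and field list are illustrative, and `service` is assumed to be an already-built Drive v3 Resource:

# Illustrative only: stream every non-trashed file the authorized user can see,
# letting the helper above handle nextPageToken paging and retries.
def iter_all_files(service):
    yield from execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        corpora="user",
        fields="nextPageToken, files(id, name, mimeType)",
        q="trashed = false",
    )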
@@ -1,18 +0,0 @@
from enum import Enum
from typing import Any


class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"


GoogleDriveFileType = dict[str, Any]
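The enum above pairs with the export logic elsewhere in this diff (Google-native spreadsheets are exported as CSV, Google docs and slides as plain text, while Office, PDF, and plain-text formats are downloaded directly). A compact, hypothetical sketch of that mapping:

# Illustrative mapping only; the authoritative logic is _extract_text in this diff.
def pick_export_mime_type(mime_type: str) -> str | None:
    if mime_type == GDriveMimeType.SPREADSHEET.value:
        return "text/csv"
    if mime_type in (GDriveMimeType.DOC.value, GDriveMimeType.PPT.value):
        return "text/plain"
    # other formats are fetched with get_media rather than exported
    return None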
@@ -56,11 +56,7 @@ class PollConnector(BaseConnector):

class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
raise NotImplementedError



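Both signatures of retrieve_all_slim_documents appear in this hunk (with and without the start/end window). A minimal, hypothetical implementation against the parameter-less form, just to show the batched-generator shape; it is not part of the diff:

# Hypothetical example. Assumes BaseConnector, SlimDocument, and
# GenerateSlimDocumentOutput from the surrounding module.
class StaticSlimConnector(BaseConnector):
    def __init__(self, document_ids: list[str], batch_size: int = 100) -> None:
        self._document_ids = document_ids
        self._batch_size = batch_size

    def load_credentials(self, credentials: dict) -> dict | None:
        return None  # nothing to authenticate against in this toy example

    def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
        batch: list[SlimDocument] = []
        for doc_id in self._document_ids:
            batch.append(SlimDocument(id=doc_id))
            if len(batch) >= self._batch_size:
                yield batch
                batch = []
        if batch:
            yield batch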
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import functools
|
||||
import itertools
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
@@ -19,8 +18,6 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import itertools
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
@@ -26,8 +25,6 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
|
||||
|
||||
|
||||
def pywikibot_timestamp_to_utc_datetime(
|
||||
timestamp: pywikibot.time.Timestamp,
|
||||
@@ -124,6 +121,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
|
||||
# short names can only have ascii letters and digits
|
||||
|
||||
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
|
||||
self.site = pywikibot.Site(fam=self.family, code=language_code)
|
||||
self.categories = [
|
||||
|
||||
@@ -251,11 +251,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
|
||||
@@ -391,11 +391,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
from zenpy import Zenpy # type: ignore
|
||||
from zenpy.lib.api_objects import Ticket # type: ignore
|
||||
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
@@ -17,244 +20,43 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
|
||||
|
||||
class ZendeskCredentialsNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"Zendesk Credentials are not set up, was load_credentials called?"
|
||||
)
|
||||
|
||||
|
||||
class ZendeskClient:
|
||||
def __init__(self, subdomain: str, email: str, token: str):
|
||||
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
|
||||
self.auth = (f"{email}/token", token)
|
||||
|
||||
@retry_builder()
|
||||
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
content_tags: dict[str, str] = {}
|
||||
params = {"page[size]": MAX_PAGE_SIZE}
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = client.make_request("guide/content_tags", params)
|
||||
|
||||
for tag in data.get("records", []):
|
||||
content_tags[tag["id"]] = tag["name"]
|
||||
|
||||
# Check if there are more pages
|
||||
if data.get("meta", {}).get("has_more", False):
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
else:
|
||||
break
|
||||
|
||||
return content_tags
|
||||
except Exception as e:
|
||||
raise Exception(f"Error fetching content tags: {str(e)}")
|
||||
|
||||
|
||||
def _get_articles(
|
||||
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = (
|
||||
{"start_time": start_time, "page[size]": page_size}
|
||||
if start_time
|
||||
else {"page[size]": page_size}
|
||||
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
|
||||
author = BasicExpertInfo(
|
||||
display_name=article.author.name, email=article.author.email
|
||||
)
|
||||
update_time = time_str_to_utc(article.updated_at)
|
||||
|
||||
while True:
|
||||
data = client.make_request("help_center/articles", params)
|
||||
for article in data["articles"]:
|
||||
yield article
|
||||
|
||||
if not data.get("meta", {}).get("has_more"):
|
||||
break
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
|
||||
|
||||
def _get_tickets(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = {"start_time": start_time} if start_time else {"start_time": 0}
|
||||
|
||||
while True:
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
for ticket in data["tickets"]:
|
||||
yield ticket
|
||||
|
||||
if not data.get("end_of_stream", False):
|
||||
params["start_time"] = data["end_time"]
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
|
||||
author_data = client.make_request(f"users/{author_id}", {})
|
||||
user = author_data.get("user")
|
||||
return (
|
||||
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
|
||||
if user and user.get("name") and user.get("email")
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def _article_to_document(
|
||||
article: dict[str, Any],
|
||||
content_tags: dict[str, str],
|
||||
author_map: dict[str, BasicExpertInfo],
|
||||
client: ZendeskClient,
|
||||
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
|
||||
author_id = article.get("author_id")
|
||||
if not author_id:
|
||||
author = None
|
||||
else:
|
||||
author = (
|
||||
author_map.get(author_id)
|
||||
if author_id in author_map
|
||||
else _fetch_author(client, author_id)
|
||||
)
|
||||
|
||||
new_author_mapping = {author_id: author} if author_id and author else None
|
||||
|
||||
updated_at = article.get("updated_at")
|
||||
update_time = time_str_to_utc(updated_at) if updated_at else None
|
||||
|
||||
# Build metadata
|
||||
# build metadata
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"labels": [str(label) for label in article.get("label_names", []) if label],
|
||||
"labels": [str(label) for label in article.label_names if label],
|
||||
"content_tags": [
|
||||
content_tags[tag_id]
|
||||
for tag_id in article.get("content_tag_ids", [])
|
||||
for tag_id in article.content_tag_ids
|
||||
if tag_id in content_tags
|
||||
],
|
||||
}
|
||||
|
||||
# Remove empty values
|
||||
# remove empty values
|
||||
metadata = {k: v for k, v in metadata.items() if v}
|
||||
|
||||
return new_author_mapping, Document(
|
||||
id=f"article:{article['id']}",
|
||||
return Document(
|
||||
id=f"article:{article.id}",
|
||||
sections=[
|
||||
Section(
|
||||
link=article.get("html_url"),
|
||||
text=parse_html_page_basic(article["body"]),
|
||||
)
|
||||
Section(link=article.html_url, text=parse_html_page_basic(article.body))
|
||||
],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=article["title"],
|
||||
semantic_identifier=article.title,
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=[author] if author else None,
|
||||
primary_owners=[author],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def _get_comment_text(
|
||||
comment: dict[str, Any],
|
||||
author_map: dict[str, BasicExpertInfo],
|
||||
client: ZendeskClient,
|
||||
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
|
||||
author_id = comment.get("author_id")
|
||||
if not author_id:
|
||||
author = None
|
||||
else:
|
||||
author = (
|
||||
author_map.get(author_id)
|
||||
if author_id in author_map
|
||||
else _fetch_author(client, author_id)
|
||||
)
|
||||
|
||||
new_author_mapping = {author_id: author} if author_id and author else None
|
||||
|
||||
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
|
||||
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
|
||||
|
||||
return new_author_mapping, comment_text
|
||||
|
||||
|
||||
def _ticket_to_document(
|
||||
ticket: dict[str, Any],
|
||||
author_map: dict[str, BasicExpertInfo],
|
||||
client: ZendeskClient,
|
||||
default_subdomain: str,
|
||||
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
|
||||
submitter_id = ticket.get("submitter")
|
||||
if not submitter_id:
|
||||
submitter = None
|
||||
else:
|
||||
submitter = (
|
||||
author_map.get(submitter_id)
|
||||
if submitter_id in author_map
|
||||
else _fetch_author(client, submitter_id)
|
||||
)
|
||||
|
||||
new_author_mapping = (
|
||||
{submitter_id: submitter} if submitter_id and submitter else None
|
||||
)
|
||||
|
||||
updated_at = ticket.get("updated_at")
|
||||
update_time = time_str_to_utc(updated_at) if updated_at else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
if status := ticket.get("status"):
|
||||
metadata["status"] = status
|
||||
if priority := ticket.get("priority"):
|
||||
metadata["priority"] = priority
|
||||
if tags := ticket.get("tags"):
|
||||
metadata["tags"] = tags
|
||||
if ticket_type := ticket.get("type"):
|
||||
metadata["ticket_type"] = ticket_type
|
||||
|
||||
# Fetch comments for the ticket
|
||||
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
|
||||
comments = comments_data.get("comments", [])
|
||||
|
||||
comment_texts = []
|
||||
for comment in comments:
|
||||
new_author_mapping, comment_text = _get_comment_text(
|
||||
comment, author_map, client
|
||||
)
|
||||
if new_author_mapping:
|
||||
author_map.update(new_author_mapping)
|
||||
comment_texts.append(comment_text)
|
||||
|
||||
comments_text = "\n\n".join(comment_texts)
|
||||
|
||||
subject = ticket.get("subject")
|
||||
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
|
||||
|
||||
ticket_url = ticket.get("url")
|
||||
subdomain = (
|
||||
ticket_url.split("//")[1].split(".zendesk.com")[0]
|
||||
if ticket_url
|
||||
else default_subdomain
|
||||
)
|
||||
|
||||
ticket_display_url = (
|
||||
f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}"
|
||||
)
|
||||
|
||||
return new_author_mapping, Document(
|
||||
id=f"zendesk_ticket_{ticket['id']}",
|
||||
sections=[Section(link=ticket_display_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=[submitter] if submitter else None,
|
||||
metadata=metadata,
|
||||
)
|
||||
class ZendeskClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Zendesk Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector):
|
||||
@@ -264,10 +66,44 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
self.zendesk_client: Zenpy | None = None
|
||||
self.content_tags: dict[str, str] = {}
|
||||
self.content_type = content_type
|
||||
|
||||
@retry(tries=3, delay=2, backoff=2)
|
||||
def _set_content_tags(
|
||||
self, subdomain: str, email: str, token: str, page_size: int = 30
|
||||
) -> None:
|
||||
# Construct the base URL
|
||||
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
|
||||
|
||||
# Set up authentication
|
||||
auth = (f"{email}/token", token)
|
||||
|
||||
# Set up pagination parameters
|
||||
params = {"page[size]": page_size}
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Make the GET request
|
||||
response = requests.get(base_url, auth=auth, params=params)
|
||||
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
content_tag_list = data.get("records", [])
|
||||
for tag in content_tag_list:
|
||||
self.content_tags[tag["id"]] = tag["name"]
|
||||
|
||||
# Check if there are more pages
|
||||
if data.get("meta", {}).get("has_more", False):
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code}\n{response.text}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error fetching content tags: {str(e)}")
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Subdomain is actually the whole URL
|
||||
@@ -276,23 +112,87 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
.replace("https://", "")
|
||||
.split(".zendesk.com")[0]
|
||||
)
|
||||
self.subdomain = subdomain
|
||||
|
||||
self.client = ZendeskClient(
|
||||
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
|
||||
self.zendesk_client = Zenpy(
|
||||
subdomain=subdomain,
|
||||
email=credentials["zendesk_email"],
|
||||
token=credentials["zendesk_token"],
|
||||
)
|
||||
self._set_content_tags(
|
||||
subdomain,
|
||||
credentials["zendesk_email"],
|
||||
credentials["zendesk_token"],
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def _ticket_to_document(self, ticket: Ticket) -> Document:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
owner = None
|
||||
if ticket.requester and ticket.requester.name and ticket.requester.email:
|
||||
owner = [
|
||||
BasicExpertInfo(
|
||||
display_name=ticket.requester.name, email=ticket.requester.email
|
||||
)
|
||||
]
|
||||
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
if ticket.status is not None:
|
||||
metadata["status"] = ticket.status
|
||||
if ticket.priority is not None:
|
||||
metadata["priority"] = ticket.priority
|
||||
if ticket.tags:
|
||||
metadata["tags"] = ticket.tags
|
||||
if ticket.type is not None:
|
||||
metadata["ticket_type"] = ticket.type
|
||||
|
||||
# Fetch comments for the ticket
|
||||
comments = self.zendesk_client.tickets.comments(ticket=ticket)
|
||||
|
||||
# Combine all comments into a single text
|
||||
comments_text = "\n\n".join(
|
||||
[
|
||||
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
|
||||
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
|
||||
for comment in comments
|
||||
if comment.body
|
||||
]
|
||||
)
|
||||
|
||||
# Combine ticket description and comments
|
||||
description = (
|
||||
ticket.description
|
||||
if hasattr(ticket, "description") and ticket.description
|
||||
else ""
|
||||
)
|
||||
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
|
||||
|
||||
# Extract subdomain from ticket.url
|
||||
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
|
||||
|
||||
# Build the html url for the ticket
|
||||
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
|
||||
|
||||
return Document(
|
||||
id=f"zendesk_ticket_{ticket.id}",
|
||||
sections=[Section(link=ticket_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=owner,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
self.content_tags = _get_content_tag_mapping(self.client)
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
if self.content_type == "articles":
|
||||
yield from self._poll_articles(start)
|
||||
@@ -304,30 +204,26 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = _get_articles(self.client, start_time=int(start) if start else None)
|
||||
|
||||
# This one is built on the fly as there may be more many more authors than tags
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
|
||||
articles = (
|
||||
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
|
||||
if start is None
|
||||
else self.zendesk_client.help_center.articles.incremental( # type: ignore
|
||||
start_time=int(start)
|
||||
)
|
||||
)
|
||||
doc_batch = []
|
||||
for article in articles:
|
||||
if (
|
||||
article.get("body") is None
|
||||
or article.get("draft")
|
||||
article.body is None
|
||||
or article.draft
|
||||
or any(
|
||||
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
for label in article.get("label_names", [])
|
||||
for label in article.label_names
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
new_author_map, documents = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
doc_batch.append(_article_to_document(article, self.content_tags))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
@@ -338,14 +234,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
|
||||
ticket_generator = _get_tickets(
|
||||
self.client, start_time=int(start) if start else None
|
||||
)
|
||||
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
@@ -354,20 +246,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.get("status") == "deleted":
|
||||
if ticket.status == "deleted":
|
||||
continue
|
||||
|
||||
new_author_map, documents = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
doc_batch.append(self._ticket_to_document(ticket))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
@@ -385,6 +267,7 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
|
||||
@@ -57,10 +57,10 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from ee.danswer.db.api_key import get_api_key_email_pattern
|
||||
|
||||
|
||||
def get_default_admin_user_emails() -> list[str]:
|
||||
@@ -36,16 +35,12 @@ def get_default_admin_user_emails() -> list[str]:
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
def get_total_users_count(db_session: Session) -> int:
|
||||
def get_total_users(db_session: Session) -> int:
|
||||
"""
|
||||
Returns the total number of users in the system.
|
||||
This is the sum of users and invited users.
|
||||
"""
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
|
||||
.count()
|
||||
)
|
||||
user_count = db_session.query(User).count()
|
||||
invited_users = len(get_invited_users())
|
||||
return user_count + invited_users
|
||||
|
||||
|
||||
@@ -388,7 +388,7 @@ def get_chat_messages_by_session(
)

if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_call))
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -474,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_call: ToolCall | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -494,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -513,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_call=tool_call,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -749,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

@@ -10,10 +10,12 @@ from sqlalchemy.sql.expression import or_
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Credential__UserGroup
|
||||
@@ -440,7 +442,7 @@ def delete_google_drive_service_account_credentials(
|
||||
) -> None:
|
||||
credentials = fetch_credentials(db_session=db_session, user=user)
|
||||
for credential in credentials:
|
||||
if credential.credential_json.get(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY):
|
||||
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
|
||||
db_session.delete(credential)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -37,10 +37,10 @@ from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -322,18 +322,11 @@ async def get_async_session_with_tenant(
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session bound to a connection with the appropriate tenant schema set.
|
||||
This preserves the tenant ID across the session and reverts to the previous tenant ID
|
||||
after the session is closed.
|
||||
"""
|
||||
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Store the previous tenant ID
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = previous_tenant_id
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
else:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
@@ -342,35 +335,30 @@ def get_session_with_tenant(
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
try:
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
yield session
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
finally:
|
||||
# Restore the previous tenant ID
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
|
||||
@@ -734,10 +734,9 @@ class IndexAttempt(Base):
|
||||
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
# Nullable because in the past, we didn't allow swapping out embedding models live
|
||||
search_settings_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("search_settings.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
ForeignKey("search_settings.id"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -757,7 +756,7 @@ class IndexAttempt(Base):
|
||||
"ConnectorCredentialPair", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
search_settings: Mapped[SearchSettings | None] = relationship(
|
||||
search_settings: Mapped[SearchSettings] = relationship(
|
||||
"SearchSettings", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
@@ -918,15 +917,10 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())

message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=False
)
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))

# Update the relationship
message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
back_populates="tool_call",
uselist=False,
"ChatMessage", back_populates="tool_calls"
)


@@ -1057,13 +1051,12 @@ class ChatMessage(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)

tool_call: Mapped["ToolCall"] = relationship(
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
uselist=False,
)

standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,

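As I read the two hunks above, a chat message now owns a list of tool calls instead of a single one. A minimal standalone sketch of the resulting relationship pair (assumed SQLAlchemy 2.0 style, simplified columns, not the full models):

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class ToolCall(Base):
    __tablename__ = "tool_call"
    id: Mapped[int] = mapped_column(primary_key=True)
    # non-nullable FK: every tool call belongs to exactly one chat message
    message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
    message: Mapped["ChatMessage"] = relationship("ChatMessage", back_populates="tool_calls")


class ChatMessage(Base):
    __tablename__ = "chat_message"
    id: Mapped[int] = mapped_column(primary_key=True)
    # one message may carry several tool calls (a list instead of uselist=False)
    tool_calls: Mapped[list["ToolCall"]] = relationship("ToolCall", back_populates="message")
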
@@ -14,6 +14,7 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_search_settings_status
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -22,14 +23,7 @@ logger = setup_logger()
|
||||
def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
"""Get count of cc-pairs and count of successful index_attempts for the
|
||||
new model grouped by connector + credential, if it's the same, then assume
|
||||
new index is done building. If so, swap the indices and expire the old one.
|
||||
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
|
||||
old_search_settings = None
|
||||
|
||||
new index is done building. If so, swap the indices and expire the old one."""
|
||||
# The default CC-pair created for the Ingestion API is unused here, so exclude it from the count
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
@@ -49,9 +43,9 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
|
||||
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
|
||||
# Swap indices
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
now_old_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=current_search_settings,
|
||||
search_settings=now_old_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -73,6 +67,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
old_search_settings = current_search_settings
|
||||
|
||||
return old_search_settings
|
||||
if MULTI_TENANT:
|
||||
return now_old_search_settings
|
||||
return None
|
||||
|
||||
@@ -13,7 +13,6 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text files contain only the text
PLAIN_TEXT = "plain_text"
|
||||
CSV = "csv"
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
|
||||
@@ -8,13 +8,12 @@ import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
@@ -53,11 +52,11 @@ def load_all_chat_files(
|
||||
return files
|
||||
|
||||
|
||||
def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
def save_file_from_url(url: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -76,10 +75,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
|
||||
|
||||
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url, tenant_id)) for url in urls
(save_file_from_url, (url,)) for url in urls
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)
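For reference, a thread-pool sketch of the call pattern used above. The real helper is run_functions_tuples_in_parallel from danswer.utils.threadpool_concurrency; the stand-in below only illustrates the assumed (function, args) contract and is not the repository's implementation.

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


def run_tuples_in_parallel(funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]]) -> list[Any]:
    # Illustrative stand-in: submit each (function, args) pair and collect results in order
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(fn, *args) for fn, args in funcs]
        return [future.result() for future in futures]


# e.g. file_ids = run_tuples_in_parallel([(save_file_from_url, (url,)) for url in urls])
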
|
||||
@@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KeyValueStore
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -1,44 +1,72 @@
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
build_citations_system_message,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
|
||||
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
build_citation_processor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
build_quotes_processor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.force import filter_tools_for_force_tool_use
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.images.prompt import build_image_generation_user_prompt
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -46,9 +74,29 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_answer_stream_processor(
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
answer_style_configs: AnswerStyleConfig,
|
||||
) -> StreamProcessor:
|
||||
if answer_style_configs.citation_config:
|
||||
return build_citation_processor(
|
||||
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
|
||||
)
|
||||
if answer_style_configs.quotes_config:
|
||||
return build_quotes_processor(
|
||||
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
|
||||
)
|
||||
|
||||
raise RuntimeError("Not implemented yet")
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -88,6 +136,8 @@ class Answer:
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.skip_explicit_tool_calling = skip_explicit_tool_calling
|
||||
|
||||
self.message_history = message_history or []
|
||||
# used for QA flow where we only want to send a single message
|
||||
self.single_message_history = single_message_history
|
||||
@@ -112,141 +162,335 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.using_tool_calling_llm = (
|
||||
explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
if self.answer_style_config.citation_config:
|
||||
prompt_builder.update_system_prompt(
|
||||
build_citations_system_message(self.prompt_config)
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool = next(
|
||||
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
|
||||
)
|
||||
if tool is None:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
logger.info(
|
||||
f"Forcefully using tool='{tool.name}'"
|
||||
+ (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args is not None
|
||||
else ""
|
||||
)
|
||||
)
|
||||
return [tool]
|
||||
|
||||
def _handle_specified_tool_call(
|
||||
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# make a dummy tool handler
|
||||
tool_handler = ToolResponseHandler([tool])
|
||||
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
]
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
if (
|
||||
current_llm_call.force_use_tool.force_use
|
||||
and current_llm_call.force_use_tool.args is not None
|
||||
):
|
||||
tool_name, tool_args = (
|
||||
current_llm_call.force_use_tool.tool_name,
|
||||
current_llm_call.force_use_tool.args,
|
||||
)
|
||||
tool = next(
|
||||
(t for t in current_llm_call.tools if t.name == tool_name), None
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
if not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
chosen_tool_and_args = (
|
||||
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
current_llm_call, self.llm
|
||||
prompt_builder.update_user_prompt(
|
||||
build_citations_user_message(
|
||||
question=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
context_docs=final_context_documents,
|
||||
files=self.latest_query_files,
|
||||
all_doc_useful=(
|
||||
self.answer_style_config.citation_config.all_docs_useful
|
||||
if self.answer_style_config.citation_config
|
||||
else False
|
||||
),
|
||||
history_message=self.single_message_history or "",
|
||||
)
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_quotes_user_message(
|
||||
question=self.question,
|
||||
context_docs=final_context_documents,
|
||||
history_str=self.single_message_history or "",
|
||||
prompt=self.prompt_config,
|
||||
)
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# if we're skipping gen ai answer generation, we should break
|
||||
# out unless we're forcing a tool call. If we don't, we might generate an
|
||||
# answer, which is a no-no!
|
||||
if (
|
||||
self.skip_gen_ai_answer_generation
|
||||
and not current_llm_call.force_use_tool.force_use
|
||||
):
|
||||
def _raw_output_for_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if not tool_call_chunk:
|
||||
return # no tool call needed
|
||||
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}\n"
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
# as of now, we don't support multiple tool calls in sequence, which is why
|
||||
# we don't need to pass this in here
|
||||
# tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# set up "handlers" to listen to the LLM response stream and
|
||||
# feed back the processed results + handle tool call requests
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
self,
|
||||
prompt: Any,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> Iterator[str | StreamStopInfo]:
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield cast(str, message.content)
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
|
||||
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
def _raw_output_for_non_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
if self.skip_gen_ai_answer_generation:
|
||||
raise ValueError(
|
||||
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
yield response
|
||||
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
yield final
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
tool_choice=(
|
||||
"required"
|
||||
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
|
||||
else None
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -254,30 +498,94 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
single_message_history=self.single_message_history,
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
output_generator = (
|
||||
self._raw_output_for_explicit_tool_calling_llms()
|
||||
if explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
)
|
||||
and not self.skip_explicit_tool_calling
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
)
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
) -> AnswerStream:
|
||||
message = None
|
||||
|
||||
# special things we need to keep track of for the SearchTool
|
||||
# raw results that will be displayed to the user
|
||||
search_results: list[LlmDoc] | None = None
|
||||
# processed docs to feed into the LLM
|
||||
final_context_docs: list[LlmDoc] | None = None
|
||||
|
||||
for message in stream:
|
||||
if isinstance(message, ToolCallKickoff) or isinstance(
|
||||
message, ToolCallFinalResult
|
||||
):
|
||||
yield message
|
||||
elif isinstance(message, ToolResponse):
|
||||
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
# We don't need to run section merging in this flow, this variable is only used
|
||||
# below to specify the ordering of the documents for the purpose of matching
|
||||
# citations to the right search documents. The deduplication logic is more lightweight
|
||||
# there and we don't need to do it twice
|
||||
search_results = [
|
||||
llm_doc_from_inference_section(section)
|
||||
for section in cast(
|
||||
SearchResponseSummary, message.response
|
||||
).top_sections
|
||||
]
|
||||
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_docs = cast(list[LlmDoc], message.response)
|
||||
yield message
|
||||
|
||||
elif (
|
||||
message.id == SEARCH_DOC_CONTENT_ID
|
||||
and not self._return_contexts
|
||||
):
|
||||
continue
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
stream_stop_info = None
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
for item in itertools.chain([message], stream):
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
|
||||
# this should never happen, but we're seeing weird behavior here so handling for now
|
||||
if not isinstance(item, str):
|
||||
logger.error(
|
||||
f"Received non-string item in answer stream: {item}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
@@ -301,6 +609,7 @@ class Answer:
|
||||
|
||||
return citations
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: "ToolResponseHandler",
|
||||
answer_handler: "AnswerResponseHandler",
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
self.answer_handler = answer_handler
|
||||
self.is_cancelled = is_cancelled
|
||||
|
||||
def handle_llm_response(
|
||||
self,
|
||||
stream: Iterator[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
all_messages: list[BaseMessage] = []
|
||||
for message in stream:
|
||||
if self.is_cancelled():
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
return
|
||||
# tool handler doesn't do anything until the full message is received
|
||||
# NOTE: still need to run list() to get this to run
|
||||
list(self.tool_handler.handle_response_part(message, all_messages))
|
||||
yield from self.answer_handler.handle_response_part(message, all_messages)
|
||||
all_messages.append(message)
|
||||
|
||||
# potentially give back all info on the selected tool call + its result
|
||||
yield from self.tool_handler.handle_response_part(None, all_messages)
|
||||
yield from self.answer_handler.handle_response_part(None, all_messages)
|
||||
|
||||
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call)
|
||||
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -51,13 +51,14 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -12,12 +12,12 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
@@ -54,14 +54,18 @@ def default_build_user_message(
|
||||
|
||||
class AnswerPromptBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
user_message: HumanMessage,
|
||||
message_history: list[PreviousMessage],
|
||||
llm_config: LLMConfig,
|
||||
single_message_history: str | None = None,
|
||||
self, message_history: list[PreviousMessage], llm_config: LLMConfig
|
||||
) -> None:
|
||||
self.max_tokens = compute_max_llm_input_tokens(llm_config)
|
||||
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
@@ -70,24 +74,6 @@ class AnswerPromptBuilder:
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
self.raw_message_history = message_history
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
# for cases where like the QA flow where we want to condense the chat history
|
||||
# into a single message rather than a sequence of User / Assistant messages
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
self.system_message_and_token_cnt = None
|
||||
@@ -99,21 +85,18 @@ class AnswerPromptBuilder:
|
||||
)
|
||||
|
||||
def update_user_prompt(self, user_message: HumanMessage) -> None:
|
||||
if not user_message:
|
||||
self.user_message_and_token_cnt = None
|
||||
return
|
||||
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
def append_message(self, message: BaseMessage) -> None:
|
||||
"""Append a new message to the message history."""
|
||||
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
|
||||
self.new_messages_and_token_cnts.append((message, token_count))
|
||||
|
||||
def get_user_message_content(self) -> str:
|
||||
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
|
||||
return query
|
||||
|
||||
def build(self) -> list[BaseMessage]:
|
||||
def build(
|
||||
self, tool_call_summary: ToolCallSummary | None = None
|
||||
) -> list[BaseMessage]:
|
||||
if not self.user_message_and_token_cnt:
|
||||
raise ValueError("User message must be set before building prompt")
|
||||
|
||||
@@ -130,8 +113,25 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
if tool_call_summary:
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_request,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_request,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_result,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_result,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -6,6 +6,7 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.persona import get_default_prompt__read_only
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
@@ -13,7 +14,6 @@ from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
@@ -132,9 +132,10 @@ def build_citations_system_message(
|
||||
|
||||
|
||||
def build_citations_user_message(
|
||||
message: HumanMessage,
|
||||
question: str,
|
||||
prompt_config: PromptConfig,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
files: list[InMemoryChatFile],
|
||||
all_doc_useful: bool,
|
||||
history_message: str = "",
|
||||
) -> HumanMessage:
|
||||
@@ -148,7 +149,6 @@ def build_citations_user_message(
|
||||
if history_message
|
||||
else ""
|
||||
)
|
||||
query, img_urls = message_to_prompt_and_imgs(message)
|
||||
|
||||
if context_docs:
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
@@ -158,22 +158,20 @@ def build_citations_user_message(
|
||||
optional_ignore_statement=optional_ignore,
|
||||
context_docs_str=context_docs_str,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=query,
|
||||
user_query=question,
|
||||
history_block=history_block,
|
||||
)
|
||||
else:
|
||||
# if no context docs provided, assume we're in the tool calling flow
|
||||
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=query,
|
||||
user_query=question,
|
||||
history_block=history_block,
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, img_urls=img_urls)
|
||||
if img_urls
|
||||
else user_prompt
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
|
||||
return user_msg
|
||||
|
||||
@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
@@ -76,7 +75,7 @@ def _build_strong_llm_quotes_prompt(
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
message: HumanMessage,
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
@@ -87,10 +86,28 @@ def build_quotes_user_message(
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
return prompt_builder(
|
||||
question=query,
|
||||
question=question,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
prompt_builder = (
|
||||
_build_weak_llm_quotes_prompt
|
||||
if QA_PROMPT_OVERRIDE == "weak"
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
return prompt_builder(
|
||||
question=question,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
|
||||
@@ -19,7 +19,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
|
||||
from danswer.tools.search.search_utils import section_to_dict
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
|
||||
|
||||
class AnswerResponseHandler(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
# This is a dummy handler that returns nothing
|
||||
yield from []
|
||||
|
||||
|
||||
class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
return
|
||||
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# Process the new content through the citation processor
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
@@ -1,10 +1,12 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -17,104 +19,128 @@ def in_code_block(llm_text: str) -> bool:
|
||||
return count % 2 != 0
|
||||
|
||||
|
||||
class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
self.current_citations: list[int] = []
|
||||
self.past_cite_count = 0
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||
"""
|
||||
Key aspects:
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
|
||||
return
|
||||
1. Stream Processing:
|
||||
- Processes tokens one by one, allowing for real-time handling of large texts.
|
||||
|
||||
if self.stop_stream:
|
||||
next_hold = self.hold + token
|
||||
if self.stop_stream in next_hold:
|
||||
return
|
||||
if next_hold == self.stop_stream[: len(next_hold)]:
|
||||
self.hold = next_hold
|
||||
return
|
||||
2. Citation Detection:
|
||||
- Uses regex to find citations in the format [number].
|
||||
- Example: [1], [2], etc.
|
||||
|
||||
3. Citation Mapping:
|
||||
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
|
||||
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
|
||||
|
||||
4. Citation Formatting:
|
||||
- Replaces citations with properly formatted versions.
|
||||
- Adds links if available: [[1]](https://example.com)
|
||||
- Handles cases where links are not available: [[1]]()
|
||||
|
||||
5. Duplicate Handling:
|
||||
- Skips consecutive citations of the same document to avoid redundancy.
|
||||
|
||||
6. Output Generation:
|
||||
- Yields DanswerAnswerPiece objects for regular text.
|
||||
- Yields CitationInfo objects for each unique citation encountered.
|
||||
|
||||
7. Context Awareness:
|
||||
- Uses context_docs to access document information for citations.
|
||||
|
||||
This function effectively processes a stream of text, identifies and reformats citations,
|
||||
and provides both the processed text and citation information as output.
|
||||
"""
|
||||
order_mapping = doc_id_to_rank_map.order_mapping
|
||||
llm_out = ""
|
||||
max_citation_num = len(context_docs)
|
||||
citation_order = []
|
||||
curr_segment = ""
|
||||
cited_inds = set()
|
||||
hold = ""
|
||||
|
||||
raw_out = ""
|
||||
current_citations: list[int] = []
|
||||
past_cite_count = 0
|
||||
for raw_token in tokens:
|
||||
raw_out += raw_token
|
||||
if stop_stream:
|
||||
next_hold = hold + raw_token
|
||||
if stop_stream in next_hold:
|
||||
break
|
||||
if next_hold == stop_stream[: len(next_hold)]:
|
||||
hold = next_hold
|
||||
continue
|
||||
token = next_hold
|
||||
self.hold = ""
|
||||
hold = ""
|
||||
else:
|
||||
token = raw_token
|
||||
|
||||
self.curr_segment += token
|
||||
self.llm_out += token
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
|
||||
# Handle code blocks without language tags
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
return
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
if "`" in curr_segment:
|
||||
if curr_segment.endswith("`"):
|
||||
continue
|
||||
elif "```" in curr_segment:
|
||||
piece_that_comes_after = curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(llm_out):
|
||||
curr_segment = curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
|
||||
citations_found = list(re.finditer(citation_pattern, curr_segment))
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
# `past_cite_count`: number of characters since past citation
|
||||
# 5 to ensure a citation hasn't occurred
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
|
||||
current_citations = []
|
||||
|
||||
result = "" # Initialize result here
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
if citations_found and not in_code_block(llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(citation.group(1))
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[numerical_value - 1]
|
||||
real_citation_num = order_mapping[context_llm_doc.document_id]
|
||||
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
if real_citation_num not in citation_order:
|
||||
citation_order.append(real_citation_num)
|
||||
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
)
|
||||
target_citation_num = citation_order.index(real_citation_num) + 1
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
if target_citation_num in current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
curr_segment = (
|
||||
curr_segment[: length_to_add + start]
|
||||
+ curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
# by allowing it to generate citations on its own.
|
||||
if curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
context_llm_doc = context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
@@ -124,57 +150,75 @@ class CitationProcessor:
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
# Will continue attempting on subsequent loop iterations
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
past_cite_count = len(llm_out)
|
||||
current_citations.append(target_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
result += self.curr_segment[:last_citation_end]
|
||||
self.curr_segment = self.curr_segment[last_citation_end:]
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
|
||||
curr_segment = curr_segment[last_citation_end:]
|
||||
if possible_citation_found:
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
curr_segment = ""
|
||||
|
||||
if not possible_citation_found:
|
||||
result += self.curr_segment
|
||||
self.curr_segment = ""
|
||||
if curr_segment:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
|
||||
if result:
|
||||
yield DanswerAnswerPiece(answer_piece=result)
|
||||
|
||||
def build_citation_processor(
|
||||
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
) -> StreamProcessor:
|
||||
def stream_processor(
|
||||
tokens: Iterator[str],
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from extract_citations_from_stream(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
doc_id_to_rank_map=doc_id_to_rank_map,
|
||||
)
|
||||
|
||||
return stream_processor
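# --- Illustrative sketch (not the repository implementation) ---
# The citation processor above rewrites streamed "[n]" markers into markdown
# links using a document order mapping. A minimal, self-contained version of
# that rewrite step, assuming a simple list of per-document links:
import re

def rewrite_citations(text: str, links: list[str | None]) -> str:
    """Replace [n] with [[n]](link) when document n exists; leave other text untouched."""
    def _sub(match: re.Match) -> str:
        num = int(match.group(1))
        if 1 <= num <= len(links):
            return f"[[{num}]]({links[num - 1] or ''})"
        return match.group(0)  # out-of-range citation: keep as-is

    return re.sub(r"\[(\d+)\]", _sub, text)

# Example: rewrite_citations("See [1] and [3].", ["https://a", None])
# -> 'See [[1]](https://a) and [3].'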
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
@@ -154,7 +157,7 @@ def separate_answer_quotes(
|
||||
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
|
||||
|
||||
|
||||
def _process_answer(
|
||||
def process_answer(
|
||||
answer_raw: str,
|
||||
docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
@@ -192,7 +195,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
def _extract_quotes_from_completed_token_stream(
|
||||
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
|
||||
) -> DanswerQuotes:
|
||||
answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
|
||||
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
|
||||
if answer:
|
||||
logger.notice(answer)
|
||||
elif model_output:
|
||||
@@ -201,101 +204,94 @@ def _extract_quotes_from_completed_token_stream(
|
||||
return quotes
|
||||
|
||||
|
||||
class QuotesProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.is_json_prompt = is_json_prompt
|
||||
def process_model_tokens(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
"""Used in the streaming case to process the model output
|
||||
into an Answer and Quotes
|
||||
|
||||
self.found_answer_start = False if is_json_prompt else True
|
||||
self.found_answer_end = False
|
||||
self.hold_quote = ""
|
||||
self.model_output = ""
|
||||
self.hold = ""
|
||||
Yields Answer tokens back out in a dict for streaming to frontend
|
||||
When Answer section ends, yields dict with answer_finished key
|
||||
Collects all the tokens at the end to form the complete model output"""
|
||||
quote_pat = f"\n{QUOTE_PAT}"
|
||||
# Sometimes a weaker model outputs a newline instead of ':'
|
||||
quote_loose = f"\n{quote_pat[:-1]}\n"
|
||||
# Sometimes the model outputs two newlines before the quote section
|
||||
quote_pat_full = f"\n{quote_pat}"
|
||||
model_output: str = ""
|
||||
found_answer_start = False if is_json_prompt else True
|
||||
found_answer_end = False
|
||||
hold_quote = ""
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
if self.model_output:
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=self.model_output,
|
||||
context_docs=self.context_docs,
|
||||
is_json_prompt=self.is_json_prompt,
|
||||
)
|
||||
return
|
||||
for token in tokens:
|
||||
model_previous = model_output
|
||||
model_output += token
|
||||
|
||||
model_previous = self.model_output
|
||||
self.model_output += token
|
||||
|
||||
if not self.found_answer_start:
|
||||
m = answer_pattern.search(self.model_output)
|
||||
if not found_answer_start:
|
||||
m = answer_pattern.search(model_output)
|
||||
if m:
|
||||
self.found_answer_start = True
|
||||
found_answer_start = True
|
||||
|
||||
# Prevent heavy cases of hallucinations
|
||||
if self.is_json_prompt and len(self.model_output) > 70:
|
||||
# Prevent heavy cases of hallucination where the model never provides JSON
|
||||
# We want to quickly update the user - not stream forever
|
||||
if is_json_prompt and len(model_output) > 70:
|
||||
logger.warning("LLM did not produce json as prompted")
|
||||
self.found_answer_end = True
|
||||
return
|
||||
found_answer_end = True
|
||||
continue
|
||||
|
||||
remaining = self.model_output[m.end() :]
|
||||
|
||||
# Look for an unescaped quote, which means the answer is entirely contained
|
||||
# in this token e.g. if the token is `{"answer": "blah", "qu`
|
||||
quote_indices = [i for i, char in enumerate(remaining) if char == '"']
|
||||
for quote_idx in quote_indices:
|
||||
# Check if quote is escaped by counting backslashes before it
|
||||
num_backslashes = 0
|
||||
pos = quote_idx - 1
|
||||
while pos >= 0 and remaining[pos] == "\\":
|
||||
num_backslashes += 1
|
||||
pos -= 1
|
||||
# If even number of backslashes, quote is not escaped
|
||||
if num_backslashes % 2 == 0:
|
||||
yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx])
|
||||
return
|
||||
|
||||
# If no unescaped quote found, yield the remaining string
|
||||
remaining = model_output[m.end() :]
|
||||
if len(remaining) > 0:
|
||||
yield DanswerAnswerPiece(answer_piece=remaining)
|
||||
return
|
||||
continue
|
||||
|
||||
if self.found_answer_start and not self.found_answer_end:
|
||||
if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
self.found_answer_end = True
|
||||
if found_answer_start and not found_answer_end:
|
||||
if is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
found_answer_end = True
|
||||
|
||||
# return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.'
|
||||
if token:
|
||||
try:
|
||||
answer_token_section = token.index('"')
|
||||
yield DanswerAnswerPiece(
|
||||
answer_piece=self.hold_quote + token[:answer_token_section]
|
||||
answer_piece=hold_quote + token[:answer_token_section]
|
||||
)
|
||||
except ValueError:
|
||||
logger.error("Quotation mark not found in token")
|
||||
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
return
|
||||
|
||||
elif not self.is_json_prompt:
|
||||
quote_pat = f"\n{QUOTE_PAT}"
|
||||
quote_loose = f"\n{quote_pat[:-1]}\n"
|
||||
quote_pat_full = f"\n{quote_pat}"
|
||||
|
||||
if (
|
||||
quote_pat in self.hold_quote + token
|
||||
or quote_loose in self.hold_quote + token
|
||||
):
|
||||
self.found_answer_end = True
|
||||
continue
|
||||
elif not is_json_prompt:
|
||||
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
|
||||
found_answer_end = True
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
return
|
||||
if self.hold_quote + token in quote_pat_full:
|
||||
self.hold_quote += token
|
||||
return
|
||||
continue
|
||||
if hold_quote + token in quote_pat_full:
|
||||
hold_quote += token
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
hold_quote = ""
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
|
||||
self.hold_quote = ""
|
||||
logger.debug(f"Raw Model QnA Output: {model_output}")
|
||||
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=model_output,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_processor(
|
||||
context_docs: list[LlmDoc], is_json_prompt: bool
|
||||
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
|
||||
def stream_processor(
|
||||
tokens: Iterator[str],
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
return stream_processor
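# --- Hedged sketch, not the project's code ---
# The answer-extraction loop above ends the JSON "answer" field at the first
# unescaped double quote, which it detects by counting preceding backslashes.
# A standalone version of that check (function name is illustrative):
def split_at_unescaped_quote(remaining: str) -> tuple[str, bool]:
    """Return (answer_part, found_end); found_end is True once an unescaped '"' is hit."""
    for idx, char in enumerate(remaining):
        if char != '"':
            continue
        backslashes = 0
        pos = idx - 1
        while pos >= 0 and remaining[pos] == "\\":
            backslashes += 1
            pos -= 1
        if backslashes % 2 == 0:  # even number of backslashes -> quote is not escaped
            return remaining[:idx], True
    return remaining, False

# split_at_unescaped_quote('he said \\"hi\\"", "quotes"') -> ('he said \\"hi\\"', True)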
|
||||
|
||||
@@ -1,207 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolResponseHandler:
|
||||
def __init__(self, tools: list[Tool]):
|
||||
self.tools = tools
|
||||
|
||||
self.tool_call_chunk: AIMessageChunk | None = None
|
||||
self.tool_call_requests: list[ToolCall] = []
|
||||
|
||||
self.tool_runner: ToolRunner | None = None
|
||||
self.tool_call_summary: ToolCallSummary | None = None
|
||||
|
||||
self.tool_kickoff: ToolCallKickoff | None = None
|
||||
self.tool_responses: list[ToolResponse] = []
|
||||
self.tool_final_result: ToolCallFinalResult | None = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_call_for_non_tool_calling_llm(
|
||||
cls, llm_call: LLMCall, llm: LLM
|
||||
) -> tuple[Tool, dict] | None:
|
||||
if llm_call.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
(
|
||||
t
|
||||
for t in llm_call.tools
|
||||
if t.name == llm_call.force_use_tool.tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
return (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(llm_call.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
return chosen_tool_and_args
|
||||
|
||||
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
|
||||
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
|
||||
return
|
||||
|
||||
self.tool_call_requests = self.tool_call_chunk.tool_calls
|
||||
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in self.tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
|
||||
if selected_tool and selected_tool_call_request:
|
||||
break
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
return
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
|
||||
self.tool_kickoff = self.tool_runner.kickoff()
|
||||
yield self.tool_kickoff
|
||||
|
||||
for response in self.tool_runner.tool_responses():
|
||||
self.tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
self.tool_final_result = self.tool_runner.tool_final_result()
|
||||
yield self.tool_final_result
|
||||
|
||||
self.tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=self.tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
selected_tool_call_request, self.tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self._handle_tool_call()
|
||||
|
||||
if isinstance(response_item, AIMessageChunk) and (
|
||||
response_item.tool_call_chunks or response_item.tool_calls
|
||||
):
|
||||
if self.tool_call_chunk is None:
|
||||
self.tool_call_chunk = response_item
|
||||
else:
|
||||
self.tool_call_chunk += response_item # type: ignore
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
or self.tool_call_summary is None
|
||||
or self.tool_kickoff is None
|
||||
or self.tool_final_result is None
|
||||
):
|
||||
return None
|
||||
|
||||
tool_runner = self.tool_runner
|
||||
new_prompt_builder = tool_runner.tool.build_next_prompt(
|
||||
prompt_builder=current_llm_call.prompt_builder,
|
||||
tool_call_summary=self.tool_call_summary,
|
||||
tool_responses=self.tool_responses,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
)
|
||||
return LLMCall(
|
||||
prompt_builder=new_prompt_builder,
|
||||
tools=[], # for now, only allow one tool call per response
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name="",
|
||||
args=None,
|
||||
),
|
||||
files=current_llm_call.files,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
tool_call_info=[
|
||||
self.tool_kickoff,
|
||||
*self.tool_responses,
|
||||
self.tool_final_result,
|
||||
],
|
||||
)
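# --- Simplified sketch of the streaming pattern above (names are illustrative) ---
# Tool-call fragments arrive incrementally; the handler buffers them and only
# runs the tool once the stream ends (signalled here by a None sentinel),
# yielding kickoff -> tool responses -> final result, mirroring the flow above.
from collections.abc import Callable, Generator, Iterable
from typing import Any

def run_buffered_tool_call(
    fragments: Iterable[str | None],
    run_tool: Callable[[str], list[Any]],
) -> Generator[Any, None, None]:
    buffered_args = ""
    for fragment in fragments:
        if fragment is None:  # end of stream: execute the buffered call
            yield {"kickoff": buffered_args}
            for response in run_tool(buffered_args):
                yield response
            yield {"final_result": True}
            return
        buffered_args += fragment  # accumulate partial tool-call arguments

# list(run_buffered_tool_call(['{"query": "fo', 'o"}', None], lambda args: [args]))
# -> [{'kickoff': '{"query": "foo"}'}, '{"query": "foo"}', {'final_result': True}]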
|
||||
@@ -83,10 +83,8 @@ def _convert_litellm_message_to_langchain_message(
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
for tool_call in (tool_calls if tool_calls else [])
|
||||
],
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
@@ -8,7 +7,6 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
import litellm # type: ignore
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from langchain.prompts.base import StringPromptValue
|
||||
from langchain.prompts.chat import ChatPromptValue
|
||||
@@ -137,18 +135,6 @@ def translate_history_to_basemessages(
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
|
||||
def _process_csv_file(file: InMemoryChatFile) -> str:
|
||||
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
|
||||
csv_preview = df.head().to_string()
|
||||
|
||||
file_name_section = (
|
||||
f"CSV FILE NAME: {file.filename}\n"
|
||||
if file.filename
|
||||
else "CSV FILE (NO NAME PROVIDED):\n"
|
||||
)
|
||||
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
|
||||
|
||||
|
||||
def _build_content(
|
||||
message: str,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
@@ -159,26 +145,16 @@ def _build_content(
|
||||
if files
|
||||
else None
|
||||
)
|
||||
|
||||
csv_files = (
|
||||
[file for file in files if file.file_type == ChatFileType.CSV]
|
||||
if files
|
||||
else None
|
||||
)
|
||||
|
||||
if not text_files and not csv_files:
|
||||
if not text_files:
|
||||
return message
|
||||
|
||||
final_message_with_files = "FILES:\n\n"
|
||||
for file in text_files or []:
|
||||
for file in text_files:
|
||||
file_content = file.content.decode("utf-8")
|
||||
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
|
||||
final_message_with_files += (
|
||||
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
|
||||
)
|
||||
for file in csv_files or []:
|
||||
final_message_with_files += _process_csv_file(file)
|
||||
|
||||
final_message_with_files += message
|
||||
|
||||
return final_message_with_files
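# --- Hedged sketch of the CSV handling being removed in this hunk ---
# The helper read the uploaded bytes with pandas and kept only df.head() as a
# prompt-sized preview; the code-block wrapper below is simplified relative to
# the project's CODE_BLOCK_PAT.
import io

import pandas as pd

def csv_preview(raw: bytes, filename: str | None = None) -> str:
    df = pd.read_csv(io.StringIO(raw.decode("utf-8")))
    name_section = (
        f"CSV FILE NAME: {filename}\n" if filename else "CSV FILE (NO NAME PROVIDED):\n"
    )
    return f"{name_section}```\n{df.head().to_string()}\n```\n"

# csv_preview(b"a,b\n1,2\n3,4\n", "sales.csv") yields a name header plus the first rows.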
|
||||
@@ -227,28 +203,6 @@ def build_content_with_imgs(
|
||||
)
|
||||
|
||||
|
||||
def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
|
||||
if isinstance(message.content, str):
|
||||
return message.content, []
|
||||
|
||||
imgs = []
|
||||
texts = []
|
||||
for part in message.content:
|
||||
if isinstance(part, dict):
|
||||
if part.get("type") == "image_url":
|
||||
img_url = part.get("image_url", {}).get("url")
|
||||
if img_url:
|
||||
imgs.append(img_url)
|
||||
elif part.get("type") == "text":
|
||||
text = part.get("text")
|
||||
if text:
|
||||
texts.append(text)
|
||||
else:
|
||||
texts.append(part)
|
||||
|
||||
return "".join(texts), imgs
|
||||
|
||||
|
||||
def dict_based_prompt_to_langchain_prompt(
|
||||
messages: list[dict[str, str]]
|
||||
) -> list[BaseMessage]:
|
||||
|
||||
@@ -52,16 +52,12 @@ from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephr
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
@@ -206,33 +202,30 @@ def stream_answer_objects(
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type
|
||||
),
|
||||
evaluation_type=LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type,
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import redis
|
||||
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_connector_stop import RedisConnectorStop
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks and redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id: int = id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
|
||||
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
|
||||
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
|
||||
|
||||
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
|
||||
return RedisConnectorIndex(
|
||||
self.tenant_id, self.id, search_settings_id, self.redis
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[2]
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
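# --- Quick illustration of the task-id convention documented above ---
# Task ids follow "prefix_objectid_uuid"; extraction is a plain split on "_",
# which is why exactly three parts are expected.
def object_id_from_task_id(task_id: str) -> str | None:
    parts = task_id.split("_")
    return parts[1] if len(parts) == 3 else None

# object_id_from_task_id("documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc") -> "1"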
|
||||
@@ -1,97 +0,0 @@
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class is used to scan documents by cc_pair in the db and collect them into
|
||||
a unified set for syncing.
|
||||
|
||||
It differs from the other redis helpers in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on syncs triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
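# --- Hedged sketch of the lock-refresh cadence used in generate_tasks above ---
# While iterating a long result set, the Redis lock is periodically reacquired
# so it does not expire mid-loop. The timeout constant here is an assumed
# stand-in for the project's CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT.
import time
from collections.abc import Iterable, Iterator
from typing import TypeVar

import redis

T = TypeVar("T")
ASSUMED_LOCK_TIMEOUT = 60  # seconds; illustrative value

def iterate_with_lock_refresh(items: Iterable[T], lock: redis.lock.Lock) -> Iterator[T]:
    last_refresh = time.monotonic()
    for item in items:
        now = time.monotonic()
        if now - last_refresh >= ASSUMED_LOCK_TIMEOUT / 4:
            lock.reacquire()  # resets the lock's TTL (redis-py Lock API)
            last_refresh = now
        yield item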
|
||||
@@ -1,145 +0,0 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
|
||||
|
||||
class RedisConnectorDeletionFenceData(BaseModel):
|
||||
num_tasks: int | None
|
||||
submitted: datetime
|
||||
|
||||
|
||||
class RedisConnectorDelete:
|
||||
"""Manages interactions with redis for deletion tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
def get_remaining(self) -> int:
|
||||
# todo: move into fence
|
||||
remaining = cast(int, self.redis.scard(self.taskset_key))
|
||||
return remaining
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorDeletionFenceData | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorDeletionFenceData.model_validate_json(
|
||||
cast(str, fence_str)
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "connectordeletion_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
|
||||
return f"{self.PREFIX}_{self.id}_{uuid4()}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
custom_task_id = self._generate_task_id()
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on syncs triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=self.tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}"
|
||||
r.srem(taskset_key, task_id)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
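# --- Sketch of the fence round-trip used by RedisConnectorDelete above ---
# The fence value is a pydantic model serialized to JSON in a Redis string;
# writing None clears the fence. Class and function names here are illustrative.
from datetime import datetime

import redis
from pydantic import BaseModel

class FencePayload(BaseModel):
    num_tasks: int | None
    submitted: datetime

def write_fence(r: redis.Redis, key: str, payload: FencePayload | None) -> None:
    if payload is None:
        r.delete(key)  # clearing the fence means "no deletion in progress"
        return
    r.set(key, payload.model_dump_json())

def read_fence(r: redis.Redis, key: str) -> FencePayload | None:
    raw = r.get(key)
    return FencePayload.model_validate_json(raw) if raw is not None else None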
|
||||
@@ -1,146 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str | None
|
||||
|
||||
|
||||
class RedisConnectorIndex:
|
||||
"""Manages interactions with redis for indexing tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectorindexing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"  # "connectorindexing+generator"
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # connectorindexing_generator_progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorindexing_generator_complete
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.generator_progress_key = (
|
||||
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_complete_key = (
|
||||
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_lock_key = (
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
|
||||
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
|
||||
def generate_generator_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorIndexingFenceData | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_str)
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
payload: RedisConnectorIndexingFenceData | None,
|
||||
) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_progress(self) -> int | None:
|
||||
"""Returns None if the key doesn't exist. The"""
|
||||
# TODO: move into fence?
|
||||
bytes = self.redis.get(self.generator_progress_key)
|
||||
if bytes is None:
|
||||
return None
|
||||
|
||||
progress = int(cast(int, bytes))
|
||||
return progress
|
||||
|
||||
def get_completion(self) -> int | None:
|
||||
# TODO: move into fence?
|
||||
bytes = self.redis.get(self.generator_complete_key)
|
||||
if bytes is None:
|
||||
return None
|
||||
|
||||
status = int(cast(int, bytes))
|
||||
return status
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.generator_lock_key)
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
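# --- Illustrative note on the key scheme above ---
# Indexing state is namespaced per (cc_pair id, search settings id); the helper
# below only captures that naming convention and is not the project's API.
def indexing_fence_key(cc_pair_id: int, search_settings_id: int) -> str:
    return f"connectorindexing_fence_{cc_pair_id}/{search_settings_id}"

# indexing_fence_key(1, 2) -> "connectorindexing_fence_1/2"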
|
||||
@@ -1,171 +0,0 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
|
||||
|
||||
class RedisConnectorPrune:
|
||||
"""Manages interactions with redis for pruning tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
|
||||
# phase 1 - generator task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpruning+generator
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # connectorpruning_generator_progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorpruning_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
|
||||
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
|
||||
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
|
||||
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_remaining(self) -> int:
|
||||
# todo: move into fence
|
||||
remaining = cast(int, self.redis.scard(self.taskset_key))
|
||||
return remaining
|
||||
|
||||
def get_active_task_count(self) -> int:
|
||||
"""Count of active pruning tasks"""
|
||||
count = 0
|
||||
for key in self.redis.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_fence(self, value: bool) -> None:
|
||||
if not value:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
"""the fence payload is an int representing the starting number of
|
||||
pruning tasks to be processed ... just after the generator completes."""
|
||||
fence_bytes = self.redis.get(self.generator_complete_key)
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_int = cast(int, fence_bytes)
|
||||
return fence_int
|
||||
|
||||
@generator_complete.setter
|
||||
def generator_complete(self, payload: int | None) -> None:
|
||||
"""Set the payload to an int to set the fence, otherwise if None it will
|
||||
be deleted"""
|
||||
if payload is None:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
documents_to_prune: set[str],
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
lock: redis.lock.Lock | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on syncs triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=self.tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}"
|
||||
r.srem(taskset_key, task_id)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPrune.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPrune.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
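# --- Sketch of the int-valued "generator complete" fence above ---
# The stored value is the number of pruning subtasks created by the generator;
# None clears the key. Redis returns bytes, hence the int() conversion.
from typing import cast

import redis

def set_generator_complete(r: redis.Redis, key: str, count: int | None) -> None:
    if count is None:
        r.delete(key)
        return
    r.set(key, count)

def get_generator_complete(r: redis.Redis, key: str) -> int | None:
    raw = r.get(key)
    return int(cast(bytes, raw)) if raw is not None else None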
|
||||
@@ -1,34 +0,0 @@
|
||||
import redis
|
||||
|
||||
|
||||
class RedisConnectorStop:
|
||||
"""Manages interactions with redis for stop signaling. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
FENCE_PREFIX = "connectorstop_fence"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id: int = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_fence(self, value: bool) -> None:
|
||||
if not value:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
@@ -1,99 +0,0 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_fence(self, payload: int | None) -> None:
|
||||
if payload is None:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload)
|
||||
|
||||
@property
|
||||
def payload(self) -> int | None:
|
||||
bytes = self.redis.get(self.fence_key)
|
||||
if bytes is None:
|
||||
return None
|
||||
|
||||
progress = int(cast(int, bytes))
|
||||
return progress
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
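# --- Hedged sketch of the tracked-dispatch pattern repeated in these helpers ---
# A custom task id is added to a Redis taskset BEFORE the Celery task is sent
# under that same id, so completion can later be tracked by set membership.
# The prefix shown is a placeholder; queue and priority options are omitted here.
from uuid import uuid4

def dispatch_tracked_task(celery_app, redis_client, taskset_key: str, document_id: str) -> str:
    custom_task_id = f"documentset_1_{uuid4()}"  # hypothetical prefix for illustration
    redis_client.sadd(taskset_key, custom_task_id)  # track before enqueueing
    celery_app.send_task(
        "vespa_metadata_sync_task",
        kwargs=dict(document_id=document_id),
        task_id=custom_task_id,
    )
    return custom_task_id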
|
||||
@@ -1,91 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: str):
|
||||
self._tenant_id: str | None = tenant_id
|
||||
self._id: str = id
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[2]
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
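# --- Standalone sketch of the key naming scheme the abstract helper defines ---
# Keys are "<prefix>_fence_<id>" and "<prefix>_taskset_<id>"; the class below is
# only for illustration and does not touch Redis.
class KeyScheme:
    PREFIX = "documentset"  # example prefix

    def __init__(self, object_id: str) -> None:
        self._id = object_id

    @property
    def fence_key(self) -> str:
        return f"{self.PREFIX}_fence_{self._id}"  # e.g. documentset_fence_1

    @property
    def taskset_key(self) -> str:
        return f"{self.PREFIX}_taskset_{self._id}"  # e.g. documentset_taskset_1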
|
||||
@@ -1,112 +0,0 @@
import time
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session

from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.redis.redis_object_helper import RedisObjectHelper
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version


class RedisUserGroup(RedisObjectHelper):
    PREFIX = "usergroup"
    FENCE_PREFIX = PREFIX + "_fence"
    TASKSET_PREFIX = PREFIX + "_taskset"

    def __init__(self, tenant_id: str | None, id: int) -> None:
        super().__init__(tenant_id, str(id))

    @property
    def fenced(self) -> bool:
        if self.redis.exists(self.fence_key):
            return True

        return False

    def set_fence(self, payload: int | None) -> None:
        if payload is None:
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload)

    @property
    def payload(self) -> int | None:
        bytes = self.redis.get(self.fence_key)
        if bytes is None:
            return None

        progress = int(cast(int, bytes))
        return progress

    def generate_tasks(
        self,
        celery_app: Celery,
        db_session: Session,
        redis_client: Redis,
        lock: redis.lock.Lock,
        tenant_id: str | None,
    ) -> int | None:
        last_lock_time = time.monotonic()

        async_results = []

        if not global_version.is_ee_version():
            return 0

        try:
            construct_document_select_by_usergroup = fetch_versioned_implementation(
                "danswer.db.user_group",
                "construct_document_select_by_usergroup",
            )
        except ModuleNotFoundError:
            return 0

        stmt = construct_document_select_by_usergroup(int(self._id))
        for doc in db_session.scalars(stmt).yield_per(1):
            current_time = time.monotonic()
            if current_time - last_lock_time >= (
                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
            ):
                lock.reacquire()
                last_lock_time = current_time

            # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # we prefix the task id so it's easier to keep track of who created the task
            # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
            custom_task_id = f"{self.task_id_prefix}_{uuid4()}"

            # add to the set BEFORE creating the task.
            redis_client.sadd(self.taskset_key, custom_task_id)

            result = celery_app.send_task(
                "vespa_metadata_sync_task",
                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
                queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
                task_id=custom_task_id,
                priority=DanswerCeleryPriority.LOW,
            )

            async_results.append(result)

        return len(async_results)

    def reset(self) -> None:
        self.redis.delete(self.taskset_key)
        self.redis.delete(self.fence_key)

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            r.delete(key)
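A minimal sketch of the fence/taskset convention the deleted class relies on (the exact key layout lives in RedisObjectHelper, so the key names below are assumptions, and the Redis connection is purely illustrative):

import redis

r = redis.Redis()  # illustrative connection, not the pooled client Danswer uses

usergroup_id = 1
fence_key = f"usergroup_fence_{usergroup_id}"      # assumed shape of RedisUserGroup.fence_key
taskset_key = f"usergroup_taskset_{usergroup_id}"  # assumed shape of RedisUserGroup.taskset_key

# set_fence(payload): storing the number of generated tasks marks a sync as in progress.
r.set(fence_key, 42)
assert r.exists(fence_key)  # equivalent to the `fenced` property returning True

# Each custom task id ("usergroup_1_<uuid>") is added to the taskset before the task
# is sent, so progress can be tracked as workers remove their ids on completion.
r.sadd(taskset_key, "usergroup_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac")

# reset(): clearing both keys marks the sync as finished.
r.delete(taskset_key)
r.delete(fence_key)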
@@ -10,7 +10,6 @@ from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep


@@ -101,7 +100,6 @@ def check_router_auth(
                or depends_fn == api_key_dep
                or depends_fn == current_user_with_expired_token
                or depends_fn == control_plane_dep
                or depends_fn == current_cloud_superuser
            ):
                found_auth = True
                break

@@ -11,6 +11,8 @@ from sqlalchemy.orm import Session

from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.pruning.tasks import (
    try_creating_prune_generator_task,
@@ -37,7 +39,6 @@ from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
@@ -96,6 +97,8 @@ def get_cc_pair_full_info(
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
    r = get_redis_client(tenant_id=tenant_id)

    cc_pair = get_connector_credential_pair_from_id(
        cc_pair_id, db_session, user, get_editable=False
    )
@@ -131,9 +134,9 @@ def get_cc_pair_full_info(
    )

    search_settings = get_current_search_settings(db_session)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings.id)
    rci = RedisConnectorIndexing(
        cc_pair_id=cc_pair_id, search_settings_id=search_settings.id
    )

    return CCPairFullInfo.from_models(
        cc_pair_model=cc_pair,
@@ -150,7 +153,7 @@ def get_cc_pair_full_info(
        ),
        num_docs_indexed=documents_indexed,
        is_editable_for_current_user=is_editable_for_current_user,
        indexing=redis_connector_index.fenced,
        indexing=rci.is_indexing(r),
    )


@@ -260,9 +263,8 @@ def prune_cc_pair(
        )

    r = get_redis_client(tenant_id=tenant_id)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    if redis_connector.prune.fenced:
    rcp = RedisConnectorPruning(cc_pair_id)
    if rcp.is_pruning(r):
        raise HTTPException(
            status_code=HTTPStatus.CONFLICT,
            detail="Pruning task already in progress.",
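Both versions of the hunks above reduce to an existence check on a Redis "fence" key; a standalone sketch of that check, with an assumed key name and an illustrative connection:

import redis


def pruning_in_progress(r: redis.Redis, fence_key: str) -> bool:
    # In spirit, `redis_connector.prune.fenced` and `rcp.is_pruning(r)` both answer
    # "does the fence key for this connector/credential pair currently exist?"
    return bool(r.exists(fence_key))


r = redis.Redis()  # illustrative connection
if pruning_in_progress(r, "connectorpruning_fence_42"):  # assumed key shape
    raise RuntimeError("Pruning task already in progress.")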
@@ -9,13 +9,13 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from google.oauth2.credentials import Credentials  # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.background.celery.versioned_apps.primary import app as primary_app
@@ -35,7 +35,6 @@ from danswer.connectors.gmail.connector_auth import (
)
from danswer.connectors.gmail.connector_auth import upsert_google_app_gmail_cred
from danswer.connectors.google_drive.connector_auth import build_service_account_creds
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.connector_auth import delete_google_app_cred
from danswer.connectors.google_drive.connector_auth import delete_service_account_key
from danswer.connectors.google_drive.connector_auth import get_auth_url
@@ -44,13 +43,13 @@ from danswer.connectors.google_drive.connector_auth import (
    get_google_drive_creds_for_authorized_user,
)
from danswer.connectors.google_drive.connector_auth import get_service_account_key
from danswer.connectors.google_drive.connector_auth import GOOGLE_DRIVE_SCOPES
from danswer.connectors.google_drive.connector_auth import (
    update_credential_access_tokens,
)
from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred
from danswer.connectors.google_drive.connector_auth import upsert_service_account_key
from danswer.connectors.google_drive.connector_auth import verify_csrf
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.db.connector import create_connector
from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
@@ -84,7 +83,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
@@ -296,7 +294,7 @@ def upsert_service_account_credential(
    try:
        credential_base = build_service_account_creds(
            DocumentSource.GOOGLE_DRIVE,
            primary_admin_email=service_account_credential_request.google_drive_primary_admin,
            delegated_user_email=service_account_credential_request.google_drive_delegated_user,
        )
    except KvKeyNotFoundError as e:
        raise HTTPException(status_code=400, detail=str(e))
@@ -322,7 +320,7 @@ def upsert_gmail_service_account_credential(
    try:
        credential_base = build_service_account_creds(
            DocumentSource.GMAIL,
            primary_admin_email=service_account_credential_request.gmail_delegated_user,
            delegated_user_email=service_account_credential_request.gmail_delegated_user,
        )
    except KvKeyNotFoundError as e:
        raise HTTPException(status_code=400, detail=str(e))
@@ -350,14 +348,27 @@ def check_drive_tokens(
        return AuthStatus(authenticated=False)
    token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
    google_drive_creds = get_google_drive_creds_for_authorized_user(
        token_json_str=token_json_str,
        scopes=GOOGLE_DRIVE_SCOPES,
        token_json_str=token_json_str
    )
    if google_drive_creds is None:
        return AuthStatus(authenticated=False)
    return AuthStatus(authenticated=True)


@router.get("/admin/connector/google-drive/authorize/{credential_id}")
def admin_google_drive_auth(
    response: Response, credential_id: str, _: User = Depends(current_admin_user)
) -> AuthUrl:
    # set a cookie that we can read in the callback (used for `verify_csrf`)
    response.set_cookie(
        key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME,
        value=credential_id,
        httponly=True,
        max_age=600,
    )
    return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id)))


@router.post("/admin/connector/file/upload")
def upload_files(
    files: list[UploadFile],
@@ -486,10 +497,12 @@ def get_connector_indexing_status(
) -> list[ConnectorIndexingStatus]:
    indexing_statuses: list[ConnectorIndexingStatus] = []

    r = get_redis_client(tenant_id=tenant_id)

    # NOTE: If the connector is deleting behind the scenes,
    # accessing cc_pairs can be inconsistent and members like
    # connector or credential may be None.
    # Additional checks are done to make sure the connector and credential still exist.
    # Additional checks are done to make sure the connector and credential still exists.
    # TODO: make this one query ... possibly eager load or wrap in a read transaction
    # to avoid the complexity of trying to error check throughout the function
    cc_pairs = get_connector_credential_pairs(
@@ -556,9 +569,8 @@ def get_connector_indexing_status(

        in_progress = False
        if search_settings:
            redis_connector = RedisConnector(tenant_id, cc_pair.id)
            redis_connector_index = redis_connector.new_index(search_settings.id)
            if redis_connector_index.fenced:
            rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
            if r.exists(rci.fence_key):
                in_progress = True

        latest_index_attempt = cc_pair_to_latest_index_attempt.get(
@@ -939,11 +951,10 @@ def google_drive_callback(
        )
    credential_id = int(credential_id_cookie)
    verify_csrf(credential_id, callback.state)

    credentials: Credentials | None = update_credential_access_tokens(
        callback.code, credential_id, user, db_session
    )
    if credentials is None:
    if (
        update_credential_access_tokens(callback.code, credential_id, user, db_session)
        is None
    ):
        raise HTTPException(
            status_code=500, detail="Unable to fetch Google Drive access tokens"
        )

@@ -81,6 +81,18 @@ def get_cc_source_full_info(
    ]


@router.get("/credential/{id}")
def list_credentials_by_id(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials(db_session=db_session, user=user)
    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.delete("/admin/credential/{credential_id}")
def delete_credential_by_id_admin(
    credential_id: int,

@@ -377,16 +377,16 @@ class GoogleServiceAccountKey(BaseModel):


class GoogleServiceAccountCredentialRequest(BaseModel):
    google_drive_primary_admin: str | None = None  # email of user to impersonate
    google_drive_delegated_user: str | None = None  # email of user to impersonate
    gmail_delegated_user: str | None = None  # email of user to impersonate

    @model_validator(mode="after")
    def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
        if (self.google_drive_primary_admin is None) == (
        if (self.google_drive_delegated_user is None) == (
            self.gmail_delegated_user is None
        ):
            raise ValueError(
                "Exactly one of google_drive_primary_admin or gmail_delegated_user must be set"
                "Exactly one of google_drive_delegated_user or gmail_delegated_user must be set"
            )
        return self
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.document_set.models import DocumentSet
|
||||
from danswer.server.features.prompt.models import PromptSnapshot
|
||||
from danswer.server.features.tool.models import ToolSnapshot
|
||||
from danswer.server.features.tool.api import ToolSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -18,16 +18,10 @@ from danswer.db.tools import update_tool
|
||||
from danswer.server.features.tool.models import CustomToolCreate
|
||||
from danswer.server.features.tool.models import CustomToolUpdate
|
||||
from danswer.server.features.tool.models import ToolSnapshot
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
openapi_to_method_specs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
validate_openapi_schema,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.utils import is_image_generation_available
|
||||
|
||||
router = APIRouter(prefix="/tool")
|
||||
|
||||
@@ -57,7 +57,6 @@ class UserInfo(BaseModel):
|
||||
oidc_expiry: datetime | None = None
|
||||
current_token_created_at: datetime | None = None
|
||||
current_token_expiry_length: int | None = None
|
||||
is_cloud_superuser: bool = False
|
||||
organization_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
@@ -66,7 +65,6 @@ class UserInfo(BaseModel):
|
||||
user: User,
|
||||
current_token_created_at: datetime | None = None,
|
||||
expiry_length: int | None = None,
|
||||
is_cloud_superuser: bool = False,
|
||||
organization_name: str | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
@@ -92,7 +90,6 @@ class UserInfo(BaseModel):
|
||||
oidc_expiry=user.oidc_expiry if TRACK_EXTERNAL_IDP_EXPIRY else None,
|
||||
current_token_created_at=current_token_created_at,
|
||||
current_token_expiry_length=expiry_length,
|
||||
is_cloud_superuser=is_cloud_superuser,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,10 +35,9 @@ from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.auth import get_total_users_count
|
||||
from danswer.db.auth import get_total_users
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
@@ -191,6 +190,7 @@ def bulk_invite_users(
|
||||
)
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
normalized_emails = []
|
||||
try:
|
||||
for email in emails:
|
||||
@@ -206,7 +206,6 @@ def bulk_invite_users(
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
add_users_to_tenant(normalized_emails, tenant_id)
|
||||
|
||||
except IntegrityError as e:
|
||||
if isinstance(e.orig, UniqueViolation):
|
||||
raise HTTPException(
|
||||
@@ -214,8 +213,6 @@ def bulk_invite_users(
|
||||
detail="User has already been invited to a Danswer organization",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
|
||||
|
||||
initial_invited_users = get_invited_users()
|
||||
|
||||
@@ -227,7 +224,7 @@ def bulk_invite_users(
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
register_tenant_users(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
|
||||
)
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
@@ -263,7 +260,7 @@ def remove_invited_user(
|
||||
try:
|
||||
if MULTI_TENANT:
|
||||
register_tenant_users(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
@@ -477,7 +474,6 @@ def verify_user_logged_in(
|
||||
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
|
||||
# to enforce user verification here - the frontend always wants to get the info about
|
||||
# the current user regardless of if they are currently verified
|
||||
|
||||
if user is None:
|
||||
# if auth type is disabled, return a dummy user with preferences from
|
||||
# the key-value store
|
||||
@@ -504,7 +500,6 @@ def verify_user_logged_in(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
|
||||
is_cloud_superuser=user.email in SUPER_USERS,
|
||||
organization_name=organization_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
@@ -284,14 +283,13 @@ def delete_chat_session_by_id(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
main_loop = asyncio.get_event_loop()
|
||||
|
||||
def is_connected_sync() -> bool:
|
||||
def is_disconnected_sync() -> bool:
|
||||
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
|
||||
try:
|
||||
is_connected = not future.result(timeout=0.01)
|
||||
return is_connected
|
||||
return not future.result(timeout=0.01)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Asyncio timed out")
|
||||
return True
|
||||
@@ -302,7 +300,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
)
|
||||
return True
|
||||
|
||||
return is_connected_sync
|
||||
return is_disconnected_sync
|
||||
|
||||
|
||||
@router.post("/send-message")
@@ -311,7 +309,7 @@ def handle_new_chat_message(
    request: Request,
    user: User | None = Depends(current_user),
    _: None = Depends(check_token_rate_limits),
    is_connected_func: Callable[[], bool] = Depends(is_connected),
    is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
    """
    This endpoint is both used for all the following purposes:
@@ -327,7 +325,7 @@ def handle_new_chat_message(
        request (Request): The current HTTP request context.
        user (User | None): The current user, obtained via dependency injection.
        _ (None): Rate limit check is run if user/group/global rate limits are enabled.
        is_connected_func (Callable[[], bool]): Function to check client disconnection,
        is_disconnected_func (Callable[[], bool]): Function to check client disconnection,
            used to stop the streaming response if the client disconnects.

    Returns:
@@ -342,6 +340,8 @@ def handle_new_chat_message(
    ):
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    import json

    def stream_generator() -> Generator[str, None, None]:
        try:
            for packet in stream_chat_message(
@@ -354,7 +354,7 @@ def handle_new_chat_message(
                custom_tool_additional_headers=get_custom_tool_additional_request_headers(
                    request.headers
                ),
                is_connected=is_connected_func,
                is_connected=is_disconnected_func,
            ):
                yield json.dumps(packet) if isinstance(packet, dict) else packet

@@ -362,9 +362,6 @@ def handle_new_chat_message(
            logger.exception(f"Error in chat message streaming: {e}")
            yield json.dumps({"error": str(e)})

        finally:
            logger.debug("Stream generator finished")

    return StreamingResponse(stream_generator(), media_type="text/event-stream")

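The hunks above wire a per-request disconnect check into the streaming response: a dependency builds a synchronous callable that polls the request's disconnect state, and the generator stops producing packets once it reports a disconnect. A minimal standalone sketch of the same pattern (the FastAPI app, route, and generator below are illustrative, not taken from the repo):

import asyncio
import json
from collections.abc import Callable, Generator

from fastapi import Depends, FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()  # illustrative app, not the repo's router


async def is_disconnected(request: Request) -> Callable[[], bool]:
    main_loop = asyncio.get_event_loop()

    def is_disconnected_sync() -> bool:
        # Poll the request's disconnect state from the worker thread that runs the
        # sync generator; time out quickly so streaming is not stalled by the check.
        future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
        try:
            return future.result(timeout=0.01)
        except asyncio.TimeoutError:
            return False  # treat an inconclusive check as "still connected"

    return is_disconnected_sync


@app.post("/demo-stream")
def demo_stream(
    is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
    def stream_generator() -> Generator[str, None, None]:
        for i in range(100):
            if is_disconnected_func():
                break  # client went away; stop producing packets
            yield json.dumps({"packet": i})

    return StreamingResponse(stream_generator(), media_type="text/event-stream")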
@@ -557,9 +554,9 @@ def upload_files_for_chat(
    _: User | None = Depends(current_user),
) -> dict[str, list[FileDescriptor]]:
    image_content_types = {"image/jpeg", "image/png", "image/webp"}
    csv_content_types = {"text/csv"}
    text_content_types = {
        "text/plain",
        "text/csv",
        "text/markdown",
        "text/x-markdown",
        "text/x-config",
@@ -578,10 +575,8 @@ def upload_files_for_chat(
        "application/epub+zip",
    }

    allowed_content_types = (
        image_content_types.union(text_content_types)
        .union(document_content_types)
        .union(csv_content_types)
    allowed_content_types = image_content_types.union(text_content_types).union(
        document_content_types
    )

    for file in files:
@@ -591,10 +586,6 @@ def upload_files_for_chat(
        elif file.content_type in text_content_types:
            error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
            ".log, .tsv."
        elif file.content_type in csv_content_types:
            error_detail = (
                "Unsupported CSV file type. Supported CSV types include .csv."
            )
        else:
            error_detail = (
                "Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
@@ -620,10 +611,6 @@ def upload_files_for_chat(
            file_type = ChatFileType.IMAGE
            # Convert image to JPEG
            file_content, new_content_type = convert_to_jpeg(file)
        elif file.content_type in csv_content_types:
            file_type = ChatFileType.CSV
            file_content = io.BytesIO(file.file.read())
            new_content_type = file.content_type or ""
        elif file.content_type in document_content_types:
            file_type = ChatFileType.DOC
            file_content = io.BytesIO(file.file.read())

@@ -188,7 +188,7 @@ class ChatMessageDetail(BaseModel):
    chat_session_id: UUID | None = None
    citations: dict[int, int] | None = None
    files: list[FileDescriptor]
    tool_call: ToolCallFinalResult | None
    tool_calls: list[ToolCallFinalResult]

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().model_dump(mode="json", *args, **kwargs)  # type: ignore

@@ -21,7 +21,7 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()

@@ -1,59 +0,0 @@
from typing import cast
from typing import TYPE_CHECKING

from langchain_core.messages import HumanMessage

from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.tools.tool import Tool

if TYPE_CHECKING:
    from danswer.llm.answering.prompts.build import AnswerPromptBuilder
    from danswer.tools.tool_implementations.custom.custom_tool import (
        CustomToolCallSummary,
    )
    from danswer.tools.message import ToolCallSummary
    from danswer.tools.models import ToolResponse


def build_user_message_for_non_tool_calling_llm(
    message: HumanMessage,
    tool_name: str,
    *args: "ToolResponse",
) -> str:
    query, _ = message_to_prompt_and_imgs(message)

    tool_run_summary = cast("CustomToolCallSummary", args[0].response).tool_result
    return f"""
Here's the result from the {tool_name} tool:

{tool_run_summary}

Now respond to the following:

{query}
""".strip()


class BaseTool(Tool):
    def build_next_prompt(
        self,
        prompt_builder: "AnswerPromptBuilder",
        tool_call_summary: "ToolCallSummary",
        tool_responses: list["ToolResponse"],
        using_tool_calling_llm: bool,
    ) -> "AnswerPromptBuilder":
        if using_tool_calling_llm:
            prompt_builder.append_message(tool_call_summary.tool_call_request)
            prompt_builder.append_message(tool_call_summary.tool_call_result)
        else:
            prompt_builder.update_user_prompt(
                HumanMessage(
                    content=build_user_message_for_non_tool_calling_llm(
                        prompt_builder.user_message_and_token_cnt[0],
                        self.name,
                        *tool_responses,
                    )
                )
            )

        return prompt_builder
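For clarity, this is roughly what the template in `build_user_message_for_non_tool_calling_llm` renders; the standalone snippet below re-applies the same f-string to made-up values (tool name, result, and query are hypothetical):

tool_name = "weather_lookup"
tool_run_summary = "72F and sunny in Berlin"
query = "Should I bring a jacket?"

rendered = f"""
Here's the result from the {tool_name} tool:

{tool_run_summary}

Now respond to the following:

{query}
""".strip()

print(rendered)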
@@ -9,13 +9,9 @@ from sqlalchemy.orm import Session

from danswer.db.models import Persona
from danswer.db.models import Tool as ToolDBModel
from danswer.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
    InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.utils.logger import setup_logger


@@ -11,34 +11,24 @@ from pydantic import BaseModel
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.base_tool import BaseTool
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.custom.custom_tool_prompts import (
    SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import openapi_to_url
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
    SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
    SHOULD_USE_CUSTOM_TOOL_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
    TOOL_ARG_SYSTEM_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
    TOOL_ARG_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from danswer.tools.tool_implementations.custom.openapi_parsing import (
    openapi_to_method_specs,
)
from danswer.tools.tool_implementations.custom.openapi_parsing import openapi_to_url
from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.tool_implementations.custom.openapi_parsing import (
    validate_openapi_schema,
)
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
@@ -53,7 +43,7 @@ class CustomToolCallSummary(BaseModel):
    tool_result: ToolResultType


class CustomTool(BaseTool):
class CustomTool(Tool):
    def __init__(
        self,
        method_spec: MethodSpec,
21
backend/danswer/tools/custom/custom_tool_prompt_builder.py
Normal file
@@ -0,0 +1,21 @@
from typing import cast

from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.models import ToolResponse


def build_user_message_for_custom_tool_for_non_tool_calling_llm(
    query: str,
    tool_name: str,
    *args: ToolResponse,
) -> str:
    tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result
    return f"""
Here's the result from the {tool_name} tool:

{tool_run_summary}

Now respond to the following:

{query}
""".strip()
Some files were not shown because too many files have changed in this diff.