Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-07 08:35:47 +00:00)
Compare commits
4 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 233713cde3 | |
| | c0b17b4c51 | |
| | 15f30b0050 | |
| | 39d9df9b1b | |
@@ -65,7 +65,6 @@ jobs:
          NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
          NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
          NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
          NEXT_PUBLIC_GTM_ENABLED=true
        # needed due to weird interactions with the builds for different platforms
        no-cache: true
        labels: ${{ steps.meta.outputs.labels }}
@@ -13,10 +13,7 @@ on:
env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

jobs:
  integration-tests:
    # See https://runs-on.com/runners/linux/

@@ -198,9 +195,6 @@ jobs:
          -e API_SERVER_HOST=api_server \
          -e OPENAI_API_KEY=${OPENAI_API_KEY} \
          -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
          -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
          -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
          -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
          -e TEST_WEB_HOSTNAME=test-runner \
          danswer/danswer-integration:test \
          /app/tests/integration/tests \
14 .github/workflows/pr-chromatic-tests.yml vendored
@@ -3,7 +3,12 @@ concurrency:
  group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on: push
on:
  merge_group:
  pull_request:
    branches:
      - main
      - 'release/**'

env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -11,8 +16,6 @@ env:

jobs:
  playwright-tests:
    name: Playwright Tests

    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
    steps:
@@ -105,7 +108,7 @@ jobs:
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Start Docker containers
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -190,8 +193,7 @@ jobs:
          docker compose -f docker-compose.dev.yml -p danswer-stack down -v

  chromatic-tests:
    name: Chromatic Tests

    name: Run Chromatic
    needs: playwright-tests
    runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
    steps:
@@ -1,59 +0,0 @@
|
||||
"""display custom llm models
|
||||
|
||||
Revision ID: 177de57c21c9
|
||||
Revises: 4ee1287bd26a
|
||||
Create Date: 2024-11-21 11:49:04.488677
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import and_
|
||||
|
||||
revision = "177de57c21c9"
|
||||
down_revision = "4ee1287bd26a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
llm_provider = sa.table(
|
||||
"llm_provider",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("provider", sa.String),
|
||||
sa.column("model_names", postgresql.ARRAY(sa.String)),
|
||||
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
|
||||
)
|
||||
|
||||
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
|
||||
|
||||
providers_to_update = sa.select(
|
||||
llm_provider.c.id,
|
||||
llm_provider.c.model_names,
|
||||
llm_provider.c.display_model_names,
|
||||
).where(
|
||||
and_(
|
||||
~llm_provider.c.provider.in_(excluded_providers),
|
||||
llm_provider.c.model_names.isnot(None),
|
||||
)
|
||||
)
|
||||
|
||||
results = conn.execute(providers_to_update).fetchall()
|
||||
|
||||
for provider_id, model_names, display_model_names in results:
|
||||
if display_model_names is None:
|
||||
display_model_names = []
|
||||
|
||||
combined_model_names = list(set(display_model_names + model_names))
|
||||
update_stmt = (
|
||||
llm_provider.update()
|
||||
.where(llm_provider.c.id == provider_id)
|
||||
.values(display_model_names=combined_model_names)
|
||||
)
|
||||
conn.execute(update_stmt)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,6 +1,5 @@
import multiprocessing
from typing import Any
from typing import cast

from celery import bootsteps  # type: ignore
from celery import Celery
@@ -15,9 +14,7 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
    get_unfenced_index_attempt_ids,
)
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
@@ -98,15 +95,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
    # by the primary worker. This is unnecessary in the multi tenant scenario
    r = get_redis_client(tenant_id=None)

    # Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
    info: dict[str, Any] = cast(dict, r.info("replication"))
    role: str = cast(str, info.get("role"))
    connected_slaves: int = info.get("connected_slaves", 0)

    logger.info(
        f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
    )

    # For the moment, we're assuming that we are the only primary worker
    # that should be running.
    # TODO: maybe check for or clean up another zombie primary worker if we detect it
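The block removed above read the Redis replication section at worker startup to warn about replica misconfiguration. For context, a minimal standalone sketch of that check with redis-py (the connection details are placeholders, not taken from the diff):

```python
import redis

# a minimal sketch, assuming a local Redis instance; host/port are placeholders
r = redis.Redis(host="localhost", port=6379)

# INFO REPLICATION returns a dict containing, among others, "role" and "connected_slaves"
info = r.info("replication")
role = info.get("role")  # "master" or "slave"
connected_slaves = info.get("connected_slaves", 0)

# being connected to a replica (or having replicas attached) could be
# problematic for a worker that assumes a single primary
print(f"role={role} connected_slaves={connected_slaves}")
```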
@@ -4,6 +4,7 @@ from typing import Any

from sqlalchemy.orm import Session

from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
@@ -16,7 +17,6 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(

def extract_ids_from_runnable_connector(
    runnable_connector: BaseConnector,
    callback: IndexingHeartbeatInterface | None = None,
    callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
    """
    If the SlimConnector hasnt been implemented for the given connector, just pull
@@ -111,15 +111,10 @@ def extract_ids_from_runnable_connector(
    for doc_batch in doc_batch_generator:
        if callback:
            if callback.should_stop():
                raise RuntimeError(
                    "extract_ids_from_runnable_connector: Stop signal detected"
                )

                raise RuntimeError("Stop signal received")
            callback.progress(len(doc_batch))
        all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))

        if callback:
            callback.progress("extract_ids_from_runnable_connector", len(doc_batch))

    return all_connector_doc_ids
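To make the callback change concrete, here is a minimal sketch (not from the diff) of an object satisfying the `RunIndexingCallbackInterface` shape used above — a `should_stop()` check plus a `progress(amount)` that takes no tag argument; the `threading.Event` is an assumption for illustration:

```python
import threading

class SimpleCallback:
    """A minimal sketch of a should_stop()/progress(amount) callback."""

    def __init__(self) -> None:
        self._stop = threading.Event()
        self.total = 0

    def should_stop(self) -> bool:
        # polled by the looping function to abort in flight
        return self._stop.is_set()

    def progress(self, amount: int) -> None:
        # called once per document batch with the batch size
        self.total += amount

    def stop(self) -> None:
        self._stop.set()
```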
@@ -19,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client


@@ -118,7 +118,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
        return None

    # set a basic fence to start
    fence_payload = RedisConnectorDeletePayload(
    fence_payload = RedisConnectorDeletionFenceData(
        num_tasks=None,
        submitted=datetime.now(timezone.utc),
    )
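The "fence" used throughout these tasks is a Redis key holding a small payload that marks an operation as in flight. A minimal sketch of the idea — the key name and payload shape here are illustrative assumptions, not the project's exact API:

```python
import json
from datetime import datetime, timezone

import redis

r = redis.Redis()

# a hypothetical fence key for one connector/credential pair
fence_key = "connectordeletion_fence_42"

# setting the fence before enqueueing work marks the deletion as in flight
payload = {"num_tasks": None, "submitted": datetime.now(timezone.utc).isoformat()}
r.set(fence_key, json.dumps(payload))

# other tasks can then check the fence before acting
if r.exists(fence_key):
    print("deletion in progress; skip conflicting work")
```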
@@ -29,7 +29,7 @@ from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP

logger = setup_logger()
@@ -66,9 +66,9 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    if last_ext_group_sync is None:
        return True

    source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
    source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD

    # If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
    # If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
    if not source_sync_period:
        return True
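The due-check above boils down to "has at least one sync period elapsed since the last sync, with a missing period meaning always due". A self-contained sketch of that logic (the function and parameter names are illustrative):

```python
from datetime import datetime, timedelta, timezone

def is_sync_due(last_sync: datetime | None, period_seconds: int | None) -> bool:
    # never synced, or no period configured -> always due
    if last_sync is None or not period_seconds:
        return True
    next_due = last_sync + timedelta(seconds=period_seconds)
    return datetime.now(timezone.utc) >= next_due

# example: last synced 2 hours ago with a 1-hour period -> due again
print(is_sync_due(datetime.now(timezone.utc) - timedelta(hours=2), 3600))  # True
```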
@@ -3,7 +3,6 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep

import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
@@ -17,6 +16,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -33,8 +33,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -44,11 +42,9 @@ from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@@ -61,7 +57,7 @@ from shared_configs.configs import SENTRY_DSN

logger = setup_logger()


class IndexingCallback(IndexingHeartbeatInterface):
class RunIndexingCallback(RunIndexingCallbackInterface):
    def __init__(
        self,
        stop_key: str,
@@ -77,7 +73,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
        self.started: datetime = datetime.now(timezone.utc)
        self.redis_lock.reacquire()

        self.last_tag: str = ""
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)

    def should_stop(self) -> bool:
@@ -85,17 +80,15 @@ class IndexingCallback(IndexingHeartbeatInterface):
            return True
        return False

    def progress(self, tag: str, amount: int) -> None:
    def progress(self, amount: int) -> None:
        try:
            self.redis_lock.reacquire()
            self.last_tag = tag
            self.last_lock_reacquire = datetime.now(timezone.utc)
        except LockError:
            logger.exception(
                f"IndexingCallback - lock.reacquire exceptioned. "
                f"RunIndexingCallback - lock.reacquire exceptioned. "
                f"lock_timeout={self.redis_lock.timeout} "
                f"start={self.started} "
                f"last_tag={self.last_tag} "
                f"last_reacquired={self.last_lock_reacquire} "
                f"now={datetime.now(timezone.utc)}"
            )
@@ -104,54 +97,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
        self.redis_client.incrby(self.generator_progress_key, amount)


def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
    """Gets a list of unfenced index attempts. Should not be possible, so we'd typically
    want to clean them up.

    Unfenced = attempt not in terminal state and fence does not exist.
    """
    unfenced_attempts: list[int] = []

    # inner/outer/inner double check pattern to avoid race conditions when checking for
    # bad state
    # inner = index_attempt in non terminal state
    # outer = r.fence_key down

    # check the db for index attempts in a non terminal state
    attempts: list[IndexAttempt] = []
    attempts.extend(
        get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
    )
    attempts.extend(
        get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
    )

    for attempt in attempts:
        fence_key = RedisConnectorIndex.fence_key_with_ids(
            attempt.connector_credential_pair_id, attempt.search_settings_id
        )

        # if the fence is down / doesn't exist, possible error but not confirmed
        if r.exists(fence_key):
            continue

        # Between the time the attempts are first looked up and the time we see the fence down,
        # the attempt may have completed and taken down the fence normally.

        # We need to double check that the index attempt is still in a non terminal state
        # and matches the original state, which confirms we are really in a bad state.
        attempt_2 = get_index_attempt(db_session, attempt.id)
        if not attempt_2:
            continue

        if attempt.status != attempt_2.status:
            continue

        unfenced_attempts.append(attempt.id)

    return unfenced_attempts


@shared_task(
    name="check_for_indexing",
    soft_time_limit=300,
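The inner/outer/inner double check in the helper above is a general pattern: read state A, read state B, then re-read A to confirm nothing changed in between. A stripped-down sketch of the idea, with illustrative stand-in helpers rather than the project's real APIs:

```python
def find_bad_items(db_lookup, fence_exists):
    """db_lookup() returns {item_id: status}; fence_exists(id) checks Redis.

    A minimal sketch of the inner/outer/inner double-check pattern: only
    flag an item if its DB status is unchanged after the fence check, which
    rules out the race where the item finished between the two reads.
    """
    bad = []
    first = db_lookup()  # inner read #1: items in a non-terminal state
    for item_id, status in first.items():
        if fence_exists(item_id):  # outer read: fence is up, all good
            continue
        second = db_lookup()  # inner read #2 to confirm the bad state
        if second.get(item_id) == status:
            bad.append(item_id)
    return bad
```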
@@ -162,7 +107,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )
@@ -172,7 +117,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
        if not lock_beat.acquire(blocking=False):
            return None

        # check for search settings swap
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            old_search_settings = check_index_swap(db_session=db_session)
            current_search_settings = get_current_search_settings(db_session)
@@ -191,18 +135,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                    embedding_model=embedding_model,
                )

        # gather cc_pair_ids
        cc_pair_ids: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
            lock_beat.reacquire()
            cc_pairs = fetch_connector_credential_pairs(db_session)
            for cc_pair_entry in cc_pairs:
                cc_pair_ids.append(cc_pair_entry.id)

        # kick off index attempts
        for cc_pair_id in cc_pair_ids:
            lock_beat.reacquire()

            redis_connector = RedisConnector(tenant_id, cc_pair_id)
            with get_session_with_tenant(tenant_id) as db_session:
                # Get the primary search settings
@@ -259,29 +198,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                    f"search_settings={search_settings_instance.id} "
                )
                tasks_created += 1

        # Fail any index attempts in the DB that don't have fences
        # This shouldn't ever happen!
        with get_session_with_tenant(tenant_id) as db_session:
            unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
            for attempt_id in unfenced_attempt_ids:
                lock_beat.reacquire()

                attempt = get_index_attempt(db_session, attempt_id)
                if not attempt:
                    continue

                failure_reason = (
                    f"Unfenced index attempt found in DB: "
                    f"index_attempt={attempt.id} "
                    f"cc_pair={attempt.connector_credential_pair_id} "
                    f"search_settings={attempt.search_settings_id}"
                )
                task_logger.error(failure_reason)
                mark_attempt_failed(
                    attempt.id, db_session, failure_reason=failure_reason
                )

    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
@@ -291,11 +207,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                "check_for_indexing - Lock not owned on completion: "
                f"tenant={tenant_id}"
            )

    return tasks_created

@@ -400,11 +311,10 @@ def try_creating_indexing_task(
    """

    LOCK_TIMEOUT = 30
    index_attempt_id: int | None = None

    # we need to serialize any attempt to trigger indexing since it can be triggered
    # either via celery beat or manually (API call)
    lock: RedisLock = r.lock(
    lock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
        timeout=LOCK_TIMEOUT,
    )
@@ -455,8 +365,6 @@ def try_creating_indexing_task(

        custom_task_id = redis_connector_index.generate_generator_task_id()

        # when the task is sent, we have yet to finish setting up the fence
        # therefore, the task must contain code that blocks until the fence is ready
        result = celery_app.send_task(
            "connector_indexing_proxy_task",
            kwargs=dict(
@@ -477,16 +385,13 @@ def try_creating_indexing_task(
        payload.celery_task_id = result.id
        redis_connector_index.set_fence(payload)
    except Exception:
        redis_connector_index.set_fence(None)
        task_logger.exception(
            f"try_creating_indexing_task - Unexpected exception: "
            f"Unexpected exception: "
            f"tenant={tenant_id} "
            f"cc_pair={cc_pair.id} "
            f"search_settings={search_settings.id}"
        )

        if index_attempt_id is not None:
            delete_index_attempt(db_session, index_attempt_id)
        redis_connector_index.set_fence(None)
        return None
    finally:
        if lock.owned():
@@ -504,7 +409,7 @@ def connector_indexing_proxy_task(
) -> None:
    """celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
    task_logger.info(
        f"Indexing watchdog - starting: attempt={index_attempt_id} "
        f"Indexing proxy - starting: attempt={index_attempt_id} "
        f"tenant={tenant_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
@@ -512,7 +417,7 @@ def connector_indexing_proxy_task(
    client = SimpleJobClient()

    job = client.submit(
        connector_indexing_task_wrapper,
        connector_indexing_task,
        index_attempt_id,
        cc_pair_id,
        search_settings_id,
@@ -523,7 +428,7 @@ def connector_indexing_proxy_task(

    if not job:
        task_logger.info(
            f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
            f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
            f"tenant={tenant_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
@@ -531,7 +436,7 @@ def connector_indexing_proxy_task(
        return

    task_logger.info(
        f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
        f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
        f"tenant={tenant_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
@@ -555,7 +460,7 @@ def connector_indexing_proxy_task(

        if job.status == "error":
            task_logger.error(
                f"Indexing watchdog - spawned task exceptioned: "
                f"Indexing proxy - spawned task exceptioned: "
                f"attempt={index_attempt_id} "
                f"tenant={tenant_id} "
                f"cc_pair={cc_pair_id} "
@@ -567,7 +472,7 @@ def connector_indexing_proxy_task(
            break

    task_logger.info(
        f"Indexing watchdog - finished: attempt={index_attempt_id} "
        f"Indexing proxy - finished: attempt={index_attempt_id} "
        f"tenant={tenant_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
@@ -575,38 +480,6 @@ def connector_indexing_proxy_task(
    return


def connector_indexing_task_wrapper(
    index_attempt_id: int,
    cc_pair_id: int,
    search_settings_id: int,
    tenant_id: str | None,
    is_ee: bool,
) -> int | None:
    """Just wraps connector_indexing_task so we can log any exceptions before
    re-raising it."""
    result: int | None = None

    try:
        result = connector_indexing_task(
            index_attempt_id,
            cc_pair_id,
            search_settings_id,
            tenant_id,
            is_ee,
        )
    except:
        logger.exception(
            f"connector_indexing_task exceptioned: "
            f"tenant={tenant_id} "
            f"index_attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )
        raise

    return result


def connector_indexing_task(
    index_attempt_id: int,
    cc_pair_id: int,
@@ -661,7 +534,6 @@ def connector_indexing_task(
    if redis_connector.delete.fenced:
        raise RuntimeError(
            f"Indexing will not start because connector deletion is in progress: "
            f"attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"fence={redis_connector.delete.fence_key}"
        )
@@ -669,18 +541,18 @@ def connector_indexing_task(
    if redis_connector.stop.fenced:
        raise RuntimeError(
            f"Indexing will not start because a connector stop signal was detected: "
            f"attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"fence={redis_connector.stop.fence_key}"
        )

    while True:
        if not redis_connector_index.fenced:  # The fence must exist
        # wait for the fence to come up
        if not redis_connector_index.fenced:
            raise ValueError(
                f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
            )

        payload = redis_connector_index.payload  # The payload must exist
        payload = redis_connector_index.payload
        if not payload:
            raise ValueError("connector_indexing_task: payload invalid or not found")

@@ -703,7 +575,7 @@ def connector_indexing_task(
        )
        break

    lock: RedisLock = r.lock(
    lock = r.lock(
        redis_connector_index.generator_lock_key,
        timeout=CELERY_INDEXING_LOCK_TIMEOUT,
    )
@@ -712,7 +584,7 @@ def connector_indexing_task(
    if not acquired:
        logger.warning(
            f"Indexing task already running, exiting...: "
            f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
            f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
        )
        return None

@@ -747,7 +619,7 @@ def connector_indexing_task(
    )

    # define a callback class
    callback = IndexingCallback(
    callback = RunIndexingCallback(
        redis_connector.stop.fence_key,
        redis_connector_index.generator_progress_key,
        lock,
@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -277,7 +277,7 @@ def connector_pruning_generator_task(
        cc_pair.credential,
    )

    callback = IndexingCallback(
    callback = RunIndexingCallback(
        redis_connector.stop.fence_key,
        redis_connector.prune.generator_progress_key,
        lock,
@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import cast

import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -46,10 +47,13 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -645,26 +649,20 @@ def monitor_ccpair_indexing_taskset(
        # the task is still setting up
        return

    # Read result state BEFORE generator_complete_key to avoid a race condition
    # never use any blocking methods on the result from inside a task!
    result: AsyncResult = AsyncResult(payload.celery_task_id)
    result_state = result.state

    # inner/outer/inner double check pattern to avoid race conditions when checking for
    # bad state

    # inner = get_completion / generator_complete not signaled
    # outer = result.state in READY state
    status_int = redis_connector_index.get_completion()
    if status_int is None:  # inner signal not set ... possible error
        result_state = result.state
        if (
            result_state in READY_STATES
        ):  # outer signal in terminal state ... possible error
            # Now double check!
    if status_int is None:  # completion signal not set ... check for errors
        # If we get here, and then the task both sets the completion signal and finishes,
        # we will incorrectly abort the task. We must check result state, then check
        # get_completion again to avoid the race condition.
        if result_state in READY_STATES:
            if redis_connector_index.get_completion() is None:
                # inner signal still not set (and cannot change when outer result_state is READY)
                # Task is finished but generator complete isn't set.
                # We have a problem! Worker may have crashed.

                # IF the task state is READY, THEN generator_complete should be set
                # if it isn't, then the worker crashed
                msg = (
                    f"Connector indexing aborted or exceptioned: "
                    f"attempt={payload.index_attempt_id} "
@@ -699,6 +697,37 @@ def monitor_ccpair_indexing_taskset(
    redis_connector_index.reset()
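The new comments spell out why the read order matters: the Celery result state must be captured before re-checking the completion signal, otherwise a task that finishes between the two reads is misclassified as crashed. A minimal sketch of the safe ordering, with stand-in helpers (not the project's real APIs):

```python
# get_completion() stands in for the Redis completion signal;
# result_state is the Celery AsyncResult.state captured up front
READY_STATES = {"SUCCESS", "FAILURE", "REVOKED"}

def worker_crashed(result_state: str, get_completion) -> bool:
    # result_state was read BEFORE this first completion check
    if get_completion() is not None:
        return False  # completion signal set; all good
    if result_state not in READY_STATES:
        return False  # task still running; the signal may come later
    # double check: a task finishing between the reads would have set the
    # signal by now; a terminal state with no signal means the worker died
    return get_completion() is None
```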

def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
    """Gets a list of unfenced index attempts. Should not be possible, so we'd typically
    want to clean them up.

    Unfenced = attempt not in terminal state and fence does not exist.
    """
    unfenced_attempts: list[int] = []

    # do some cleanup before clearing fences
    # check the db for any outstanding index attempts
    attempts: list[IndexAttempt] = []
    attempts.extend(
        get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
    )
    attempts.extend(
        get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
    )

    for attempt in attempts:
        # if attempts exist in the db but we don't detect them in redis, mark them as failed
        fence_key = RedisConnectorIndex.fence_key_with_ids(
            attempt.connector_credential_pair_id, attempt.search_settings_id
        )
        if r.exists(fence_key):
            continue

        unfenced_attempts.append(attempt.id)

    return unfenced_attempts


@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
    """This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -750,6 +779,25 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
            f"permissions_sync={n_permissions_sync} "
        )

        # Fail any index attempts in the DB that don't have fences
        with get_session_with_tenant(tenant_id) as db_session:
            unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
            for attempt_id in unfenced_attempt_ids:
                attempt = get_index_attempt(db_session, attempt_id)
                if not attempt:
                    continue

                failure_reason = (
                    f"Unfenced index attempt found in DB: "
                    f"index_attempt={attempt.id} "
                    f"cc_pair={attempt.connector_credential_pair_id} "
                    f"search_settings={attempt.search_settings_id}"
                )
                task_logger.warning(failure_reason)
                mark_attempt_failed(
                    attempt.id, db_session, failure_reason=failure_reason
                )

        lock_beat.reacquire()
        if r.exists(RedisConnectorCredentialPair.get_fence_key()):
            monitor_connector_taskset(r)
@@ -1,5 +1,7 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -29,7 +31,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
@@ -40,6 +42,19 @@ logger = setup_logger()

INDEXING_TRACER_NUM_PRINT_ENTRIES = 5


class RunIndexingCallbackInterface(ABC):
    """Defines a callback interface to be passed to
    to run_indexing_entrypoint."""

    @abstractmethod
    def should_stop(self) -> bool:
        """Signal to stop the looping function in flight."""

    @abstractmethod
    def progress(self, amount: int) -> None:
        """Send progress updates to the caller."""


def _get_connector_runner(
    db_session: Session,
    attempt: IndexAttempt,
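To illustrate how this interface is consumed, here is a minimal sketch of a loop honoring `should_stop()` and reporting `progress(amount)` once per batch; the batch source is an assumption for illustration:

```python
from typing import Iterable

def process_batches(
    batches: Iterable[list[str]],
    callback: "RunIndexingCallbackInterface | None" = None,
) -> int:
    # a minimal sketch of the consumption pattern used by the indexing loop
    total = 0
    for batch in batches:
        if callback:
            if callback.should_stop():
                raise RuntimeError("Stop signal received")
            callback.progress(len(batch))
        total += len(batch)
    return total
```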
@@ -91,7 +106,7 @@ def _run_indexing(
    db_session: Session,
    index_attempt: IndexAttempt,
    tenant_id: str | None,
    callback: IndexingHeartbeatInterface | None = None,
    callback: RunIndexingCallbackInterface | None = None,
) -> None:
    """
    1. Get documents which are either new or updated from specified application
@@ -123,7 +138,13 @@ def _run_indexing(

    embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
        search_settings=search_settings,
        callback=callback,
        heartbeat=IndexingHeartbeat(
            index_attempt_id=index_attempt.id,
            db_session=db_session,
            # let the world know we're still making progress after
            # every 10 batches
            freq=10,
        ),
    )

    indexing_pipeline = build_indexing_pipeline(
@@ -136,7 +157,6 @@ def _run_indexing(
        ),
        db_session=db_session,
        tenant_id=tenant_id,
        callback=callback,
    )

    db_cc_pair = index_attempt.connector_credential_pair
@@ -208,9 +228,7 @@ def _run_indexing(
            # contents still need to be initially pulled.
            if callback:
                if callback.should_stop():
                    raise RuntimeError(
                        "_run_indexing: Connector stop signal detected"
                    )
                    raise RuntimeError("Connector stop signal detected")

            # TODO: should we move this into the above callback instead?
            db_session.refresh(db_cc_pair)
@@ -271,7 +289,7 @@ def _run_indexing(
                db_session.commit()

                if callback:
                    callback.progress("_run_indexing", len(doc_batch))
                    callback.progress(len(doc_batch))

                # This new value is updated every batch, so UI can refresh per batch update
                update_docs_indexed(
@@ -401,7 +419,7 @@ def run_indexing_entrypoint(
    tenant_id: str | None,
    connector_credential_pair_id: int,
    is_ee: bool = False,
    callback: IndexingHeartbeatInterface | None = None,
    callback: RunIndexingCallbackInterface | None = None,
) -> None:
    try:
        if is_ee:
@@ -7,9 +7,9 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -70,7 +70,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
    ) -> None:
        self.batch_size = batch_size
        self.continue_on_failure = continue_on_failure
        self._confluence_client: OnyxConfluence | None = None
        self.confluence_client: OnyxConfluence | None = None
        self.is_cloud = is_cloud

        # Remove trailing slash from wiki_base if present
@@ -97,44 +97,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        self.cql_label_filter = ""
        if labels_to_skip:
            labels_to_skip = list(set(labels_to_skip))
            comma_separated_labels = ",".join(
                f"'{quote(label)}'" for label in labels_to_skip
            )
            comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
            self.cql_label_filter = f" and label not in ({comma_separated_labels})"

    @property
    def confluence_client(self) -> OnyxConfluence:
        if self._confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence")
        return self._confluence_client

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
        # for a list of other hidden constructor args
        self._confluence_client = build_confluence_client(
            credentials=credentials,
        self.confluence_client = build_confluence_client(
            credentials_json=credentials,
            is_cloud=self.is_cloud,
            wiki_base=self.wiki_base,
        )
        return None

    def _get_comment_string_for_page_id(self, page_id: str) -> str:
        if self.confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence")

        comment_string = ""

        comment_cql = f"type=comment and container='{page_id}'"
        comment_cql += self.cql_label_filter

        expand = ",".join(_COMMENT_EXPANSION_FIELDS)
        for comment in self.confluence_client.paginated_cql_retrieval(
        for comments in self.confluence_client.paginated_cql_page_retrieval(
            cql=comment_cql,
            expand=expand,
        ):
            comment_string += "\nComment:\n"
            comment_string += extract_text_from_confluence_html(
                confluence_client=self.confluence_client,
                confluence_object=comment,
                fetched_titles=set(),
            )
            for comment in comments:
                comment_string += "\nComment:\n"
                comment_string += extract_text_from_confluence_html(
                    confluence_client=self.confluence_client,
                    confluence_object=comment,
                )

        return comment_string

@@ -146,6 +141,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        If its a page, it extracts the text, adds the comments for the document text.
        If its an attachment, it just downloads the attachment and converts that into a document.
        """
        if self.confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence")

        # The url and the id are the same
        object_url = build_confluence_document_id(
            self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
@@ -155,19 +153,16 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        # Extract text from page
        if confluence_object["type"] == "page":
            object_text = extract_text_from_confluence_html(
                confluence_client=self.confluence_client,
                confluence_object=confluence_object,
                fetched_titles={confluence_object.get("title", "")},
                self.confluence_client, confluence_object
            )
            # Add comments to text
            object_text += self._get_comment_string_for_page_id(confluence_object["id"])
        elif confluence_object["type"] == "attachment":
            object_text = attachment_to_content(
                confluence_client=self.confluence_client, attachment=confluence_object
                self.confluence_client, confluence_object
            )

        if object_text is None:
            # This only happens for attachments that are not parseable
            return None

        # Get space name
@@ -198,39 +193,44 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        )

    def _fetch_document_batches(self) -> GenerateDocumentsOutput:
        if self.confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence")

        doc_batch: list[Document] = []
        confluence_page_ids: list[str] = []

        page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
        # Fetch pages as Documents
        for page in self.confluence_client.paginated_cql_retrieval(
        for page_batch in self.confluence_client.paginated_cql_page_retrieval(
            cql=page_query,
            expand=",".join(_PAGE_EXPANSION_FIELDS),
            limit=self.batch_size,
        ):
            confluence_page_ids.append(page["id"])
            doc = self._convert_object_to_document(page)
            if doc is not None:
                doc_batch.append(doc)
                if len(doc_batch) >= self.batch_size:
                    yield doc_batch
                    doc_batch = []
            for page in page_batch:
                confluence_page_ids.append(page["id"])
                doc = self._convert_object_to_document(page)
                if doc is not None:
                    doc_batch.append(doc)
                    if len(doc_batch) >= self.batch_size:
                        yield doc_batch
                        doc_batch = []

        # Fetch attachments as Documents
        for confluence_page_id in confluence_page_ids:
            attachment_cql = f"type=attachment and container='{confluence_page_id}'"
            attachment_cql += self.cql_label_filter
            # TODO: maybe should add time filter as well?
            for attachment in self.confluence_client.paginated_cql_retrieval(
            for attachments in self.confluence_client.paginated_cql_page_retrieval(
                cql=attachment_cql,
                expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
            ):
                doc = self._convert_object_to_document(attachment)
                if doc is not None:
                    doc_batch.append(doc)
                    if len(doc_batch) >= self.batch_size:
                        yield doc_batch
                        doc_batch = []
                for attachment in attachments:
                    doc = self._convert_object_to_document(attachment)
                    if doc is not None:
                        doc_batch.append(doc)
                        if len(doc_batch) >= self.batch_size:
                            yield doc_batch
                            doc_batch = []

        if doc_batch:
            yield doc_batch
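The recurring change in this file swaps an item-level iterator for a batch-level one: `paginated_cql_page_retrieval` yields lists, so every loop gains an inner `for item in batch`. A minimal sketch of the two consumption shapes (the generator names here are illustrative):

```python
from typing import Any, Iterator

def yield_items() -> Iterator[dict[str, Any]]:
    # item-level: one dict per iteration
    yield {"id": "1"}
    yield {"id": "2"}

def yield_batches() -> Iterator[list[dict[str, Any]]]:
    # batch-level: one list (an API results page) per iteration
    yield [{"id": "1"}, {"id": "2"}]

ids_old = [item["id"] for item in yield_items()]
ids_new = [item["id"] for batch in yield_batches() for item in batch]
assert ids_old == ids_new
```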
@@ -255,47 +255,52 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateSlimDocumentOutput:
        if self.confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence")

        doc_metadata_list: list[SlimDocument] = []

        restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)

        page_query = self.cql_page_query + self.cql_label_filter
        for page in self.confluence_client.cql_paginate_all_expansions(
        for pages in self.confluence_client.cql_paginate_all_expansions(
            cql=page_query,
            expand=restrictions_expand,
        ):
            # If the page has restrictions, add them to the perm_sync_data
            # These will be used by doc_sync.py to sync permissions
            perm_sync_data = {
                "restrictions": page.get("restrictions", {}),
                "space_key": page.get("space", {}).get("key"),
            }
            for page in pages:
                # If the page has restrictions, add them to the perm_sync_data
                # These will be used by doc_sync.py to sync permissions
                perm_sync_data = {
                    "restrictions": page.get("restrictions", {}),
                    "space_key": page.get("space", {}).get("key"),
                }

            doc_metadata_list.append(
                SlimDocument(
                    id=build_confluence_document_id(
                        self.wiki_base,
                        page["_links"]["webui"],
                        self.is_cloud,
                    ),
                    perm_sync_data=perm_sync_data,
                )
            )
            attachment_cql = f"type=attachment and container='{page['id']}'"
            attachment_cql += self.cql_label_filter
            for attachment in self.confluence_client.cql_paginate_all_expansions(
                cql=attachment_cql,
                expand=restrictions_expand,
            ):
                doc_metadata_list.append(
                    SlimDocument(
                        id=build_confluence_document_id(
                            self.wiki_base,
                            attachment["_links"]["webui"],
                            page["_links"]["webui"],
                            self.is_cloud,
                        ),
                        perm_sync_data=perm_sync_data,
                    )
                )
            yield doc_metadata_list
            doc_metadata_list = []
                attachment_cql = f"type=attachment and container='{page['id']}'"
                attachment_cql += self.cql_label_filter
                for attachments in self.confluence_client.cql_paginate_all_expansions(
                    cql=attachment_cql,
                    expand=restrictions_expand,
                ):
                    for attachment in attachments:
                        doc_metadata_list.append(
                            SlimDocument(
                                id=build_confluence_document_id(
                                    self.wiki_base,
                                    attachment["_links"]["webui"],
                                    self.is_cloud,
                                ),
                                perm_sync_data=perm_sync_data,
                            )
                        )
                yield doc_metadata_list
                doc_metadata_list = []
@@ -20,10 +20,6 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@@ -84,7 +80,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
TIMEOUT = 3600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
@@ -99,10 +95,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
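The wrapper combines a bounded retry count with a wall-clock deadline based on `time.monotonic()`. A self-contained sketch of that shape — `flaky_call` and the backoff formula are stand-ins, not the project's exact logic:

```python
import time

MAX_RETRIES = 5
TIMEOUT = 3600  # overall deadline in seconds

def call_with_retries(flaky_call):
    # a minimal sketch: bounded retries under a monotonic wall-clock deadline
    timeout_at = time.monotonic() + TIMEOUT
    for attempt in range(MAX_RETRIES):
        if time.monotonic() > timeout_at:
            raise TimeoutError("overall deadline exceeded")
        try:
            return flaky_call()
        except Exception:
            # simple exponential backoff; the real code derives the delay
            # from the HTTP error, then sleeps in 1-second slices so a
            # future stop signal could interrupt the wait
            delay_until = time.monotonic() + 2 ** attempt
            while time.monotonic() < delay_until:
                time.sleep(1)
    raise RuntimeError("retries exhausted")
```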
@@ -149,7 +141,7 @@ class OnyxConfluence(Confluence):

    def _paginate_url(
        self, url_suffix: str, limit: int | None = None
    ) -> Iterator[dict[str, Any]]:
    ) -> Iterator[list[dict[str, Any]]]:
        """
        This will paginate through the top level query.
        """
@@ -161,43 +153,46 @@ class OnyxConfluence(Confluence):

        while url_suffix:
            try:
                logger.debug(f"Making confluence call to {url_suffix}")
                next_response = self.get(url_suffix)
            except Exception as e:
                logger.warning(f"Error in confluence call to {url_suffix}")

                # If the problematic expansion is in the url, replace it
                # with the replacement expansion and try again
                # If that fails, raise the error
                if _PROBLEMATIC_EXPANSIONS not in url_suffix:
                    logger.exception(f"Error in confluence call to {url_suffix}")
                    raise e
                logger.warning(
                    f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
                    " and trying again."
                )
                url_suffix = url_suffix.replace(
                    _PROBLEMATIC_EXPANSIONS,
                    _REPLACEMENT_EXPANSIONS,
                )
                continue

            # yield the results individually
            yield from next_response.get("results", [])

                logger.exception("Error in danswer_cql: \n")
                raise e
            yield next_response.get("results", [])
            url_suffix = next_response.get("_links", {}).get("next")

    def paginated_cql_retrieval(
    def paginated_groups_retrieval(
        self,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        return self._paginate_url("rest/api/group", limit)

    def paginated_group_members_retrieval(
        self,
        group_name: str,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        group_name = quote(group_name)
        return self._paginate_url(f"rest/api/group/{group_name}/member", limit)

    def paginated_cql_user_retrieval(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
        The content/search endpoint can be used to fetch pages, attachments, and comments.
        """
    ) -> Iterator[list[dict[str, Any]]]:
        expand_string = f"&expand={expand}" if expand else ""
        yield from self._paginate_url(
        return self._paginate_url(
            f"rest/api/search/user?cql={cql}{expand_string}", limit
        )

    def paginated_cql_page_retrieval(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        expand_string = f"&expand={expand}" if expand else ""
        return self._paginate_url(
            f"rest/api/content/search?cql={cql}{expand_string}", limit
        )

@@ -206,7 +201,7 @@ class OnyxConfluence(Confluence):
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
    ) -> Iterator[list[dict[str, Any]]]:
        """
        This function will paginate through the top level query first, then
        paginate through all of the expansions.
@@ -226,110 +221,6 @@ class OnyxConfluence(Confluence):
            for item in data:
                _traverse_and_update(item)

        for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
            _traverse_and_update(confluence_object)
            yield confluence_object

    def paginated_cql_user_retrieval(
        self,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
        The search/user endpoint can be used to fetch users.
        It's a seperate endpoint from the content/search endpoint used only for users.
        Otherwise it's very similar to the content/search endpoint.
        """
        cql = "type=user"
        url = "rest/api/search/user" if self.cloud else "rest/api/search"
        expand_string = f"&expand={expand}" if expand else ""
        url += f"?cql={cql}{expand_string}"
        yield from self._paginate_url(url, limit)

    def paginated_groups_by_user_retrieval(
        self,
        user: dict[str, Any],
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
        This is not an SQL like query.
        It's a confluence specific endpoint that can be used to fetch groups.
        """
        user_field = "accountId" if self.cloud else "key"
        user_value = user["accountId"] if self.cloud else user["userKey"]
        # Server uses userKey (but calls it key during the API call), Cloud uses accountId
        user_query = f"{user_field}={quote(user_value)}"

        url = f"rest/api/user/memberof?{user_query}"
        yield from self._paginate_url(url, limit)

    def paginated_groups_retrieval(
        self,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
        This is not an SQL like query.
        It's a confluence specific endpoint that can be used to fetch groups.
        """
        yield from self._paginate_url("rest/api/group", limit)

    def paginated_group_members_retrieval(
        self,
        group_name: str,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
        This is not an SQL like query.
        It's a confluence specific endpoint that can be used to fetch the members of a group.
        THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
        E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
        """
        group_name = quote(group_name)
        yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)


def _validate_connector_configuration(
    credentials: dict[str, Any],
    is_cloud: bool,
    wiki_base: str,
) -> None:
    # test connection with direct client, no retries
    confluence_client_without_retries = Confluence(
        api_version="cloud" if is_cloud else "latest",
        url=wiki_base.rstrip("/"),
        username=credentials["confluence_username"] if is_cloud else None,
        password=credentials["confluence_access_token"] if is_cloud else None,
        token=credentials["confluence_access_token"] if not is_cloud else None,
    )
    spaces = confluence_client_without_retries.get_all_spaces(limit=1)

    if not spaces:
        raise RuntimeError(
            f"No spaces found at {wiki_base}! "
            "Check your credentials and wiki_base and make sure "
            "is_cloud is set correctly."
        )


def build_confluence_client(
    credentials: dict[str, Any],
    is_cloud: bool,
    wiki_base: str,
) -> OnyxConfluence:
    _validate_connector_configuration(
        credentials=credentials,
        is_cloud=is_cloud,
        wiki_base=wiki_base,
    )
    return OnyxConfluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present
        url=wiki_base.rstrip("/"),
        # passing in username causes issues for Confluence data center
        username=credentials["confluence_username"] if is_cloud else None,
        password=credentials["confluence_access_token"] if is_cloud else None,
        token=credentials["confluence_access_token"] if not is_cloud else None,
        backoff_and_retry=True,
        max_backoff_retries=10,
        max_backoff_seconds=60,
    )

        for results in self.paginated_cql_page_retrieval(cql, expand, limit):
            _traverse_and_update(results)
            yield results
@@ -2,7 +2,6 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote

import bs4

@@ -72,9 +71,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:


def extract_text_from_confluence_html(
    confluence_client: OnyxConfluence,
    confluence_object: dict[str, Any],
    fetched_titles: set[str],
    confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
) -> str:
    """Parse a Confluence html page and replace the 'user Id' by the real
    User Display Name
@@ -82,7 +79,7 @@ def extract_text_from_confluence_html(
    Args:
        confluence_object (dict): The confluence object as a dict
        confluence_client (Confluence): Confluence client
        fetched_titles (set[str]): The titles of the pages that have already been fetched

    Returns:
        str: loaded and formated Confluence page
    """
@@ -104,72 +101,38 @@ def extract_text_from_confluence_html(
        # Include @ sign for tagging, more clear for LLM
        user.replaceWith("@" + _get_user(confluence_client, user_id))

    for html_page_reference in soup.findAll("ac:structured-macro"):
        # Here, we only want to process page within page macros
        if html_page_reference.attrs.get("ac:name") != "include":
            continue

        page_data = html_page_reference.find("ri:page")
        if not page_data:
            logger.warning(
                f"Skipping retrieval of {html_page_reference} because because page data is missing"
            )
            continue

        page_title = page_data.attrs.get("ri:content-title")
        if not page_title:
            # only fetch pages that have a title
            logger.warning(
                f"Skipping retrieval of {html_page_reference} because it has no title"
            )
            continue

        if page_title in fetched_titles:
            # prevent recursive fetching of pages
            logger.debug(f"Skipping {page_title} because it has already been fetched")
            continue

        fetched_titles.add(page_title)

    for html_page_reference in soup.findAll("ri:page"):
        # Wrap this in a try-except because there are some pages that might not exist
        try:
            page_query = f"type=page and title='{quote(page_title)}'"
            page_title = html_page_reference.attrs["ri:content-title"]
            if not page_title:
                continue

            page_query = f"type=page and title='{page_title}'"

            page_contents: dict[str, Any] | None = None
            # Confluence enforces title uniqueness, so we should only get one result here
            for page in confluence_client.paginated_cql_retrieval(
            for page_batch in confluence_client.paginated_cql_page_retrieval(
                cql=page_query,
                expand="body.storage.value",
                limit=1,
            ):
                page_contents = page
                page_contents = page_batch[0]
                break
        except Exception as e:
        except Exception:
            logger.warning(
                f"Error getting page contents for object {confluence_object}: {e}"
                f"Error getting page contents for object {confluence_object}"
            )
            continue

        if not page_contents:
            continue

        text_from_page = extract_text_from_confluence_html(
            confluence_client=confluence_client,
            confluence_object=page_contents,
            fetched_titles=fetched_titles,
            confluence_client, page_contents
        )

        html_page_reference.replaceWith(text_from_page)

    for html_link_body in soup.findAll("ac:link-body"):
        # This extracts the text from inline links in the page so they can be
        # represented in the document text as plain text
        try:
            text_from_link = html_link_body.text
            html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
        except Exception as e:
            logger.warning(f"Error processing ac:link-body: {e}")

    return format_document_soup(soup)


@@ -269,3 +232,20 @@ def datetime_from_string(datetime_string: str) -> datetime:
        datetime_object = datetime_object.astimezone(timezone.utc)

    return datetime_object


def build_confluence_client(
    credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
    return OnyxConfluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present
        url=wiki_base.rstrip("/"),
        # passing in username causes issues for Confluence data center
        username=credentials_json["confluence_username"] if is_cloud else None,
        password=credentials_json["confluence_access_token"] if is_cloud else None,
        token=credentials_json["confluence_access_token"] if not is_cloud else None,
        backoff_and_retry=True,
        max_backoff_retries=60,
        max_backoff_seconds=60,
    )
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
@@ -12,93 +12,129 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
|
||||
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
|
||||
from danswer.connectors.danswer_jira.utils import build_jira_client
|
||||
from danswer.connectors.danswer_jira.utils import build_jira_url
|
||||
from danswer.connectors.danswer_jira.utils import extract_jira_project
|
||||
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
|
||||
from danswer.connectors.danswer_jira.utils import get_comment_strs
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
PROJECT_URL_PAT = "projects"
|
||||
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
|
||||
def _paginate_jql_search(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
start = 0
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
parsed_url = urlparse(url)
|
||||
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
|
||||
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise Exception(f"Found Jira object not of type Issue: {issue}")
|
||||
# Split the path by '/' and find the position of 'projects' to get the project name
|
||||
split_path = parsed_url.path.split("/")
|
||||
if PROJECT_URL_PAT in split_path:
|
||||
project_pos = split_path.index(PROJECT_URL_PAT)
|
||||
if len(split_path) > project_pos + 1:
|
||||
jira_project = split_path[project_pos + 1]
|
||||
else:
|
||||
raise ValueError("No project name found in the URL")
|
||||
else:
|
||||
raise ValueError("'projects' not found in the URL")
|
||||
|
||||
if len(issues) < max_results:
|
||||
break
|
||||
return jira_base, jira_project
|
||||
|
||||
start += max_results
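# Illustrative usage sketch (not part of the diff): the generator above streams
# issues page by page, so a caller can iterate without holding the full result
# set in memory. The JQL string and page size here are hypothetical.
#
# for issue in _paginate_jql_search(jira_client, jql="project = DAN", max_results=50):
#     print(issue.key)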

def extract_text_from_adf(adf: dict | None) -> str:
    """Extracts plain text from Atlassian Document Format:
    https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/

    WARNING: This function is incomplete and will e.g. skip lists!
    """
    texts = []
    if adf is not None and "content" in adf:
        for block in adf["content"]:
            if "content" in block:
                for item in block["content"]:
                    if item["type"] == "text":
                        texts.append(item["text"])
    return " ".join(texts)


def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
    if hasattr(jira_issue.fields, field):
        return getattr(jira_issue.fields, field)

    try:
        return jira_issue.raw["fields"][field]
    except Exception:
        return None


def _get_comment_strs(
    jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
    comment_strs = []
    for comment in jira.fields.comment.comments:
        try:
            body_text = (
                comment.body
                if JIRA_API_VERSION == "2"
                else extract_text_from_adf(comment.raw["body"])
            )

            if (
                hasattr(comment, "author")
                and hasattr(comment.author, "emailAddress")
                and comment.author.emailAddress in comment_email_blacklist
            ):
                continue  # Skip adding comment if author's email is in blacklist

            comment_strs.append(body_text)
        except Exception as e:
            logger.error(f"Failed to process comment due to an error: {e}")
            continue

    return comment_strs


def fetch_jira_issues_batch(
    jira_client: JIRA,
    jql: str,
    batch_size: int,
    start_index: int,
    jira_client: JIRA,
    batch_size: int = INDEX_BATCH_SIZE,
    comment_email_blacklist: tuple[str, ...] = (),
    labels_to_skip: set[str] | None = None,
) -> Iterable[Document]:
    for issue in _paginate_jql_search(
        jira_client=jira_client,
        jql=jql,
        max_results=batch_size,
    ):
        if labels_to_skip:
            if any(label in issue.fields.labels for label in labels_to_skip):
                logger.info(
                    f"Skipping {issue.key} because it has a label to skip. Found "
                    f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
                )
                continue
) -> tuple[list[Document], int]:
    doc_batch = []

    batch = jira_client.search_issues(
        jql,
        startAt=start_index,
        maxResults=batch_size,
    )

    for jira in batch:
        if type(jira) != Issue:
            logger.warning(f"Found Jira object not of type Issue {jira}")
            continue

        if labels_to_skip and any(
            label in jira.fields.labels for label in labels_to_skip
        ):
            logger.info(
                f"Skipping {jira.key} because it has a label to skip. Found "
                f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
            )
            continue

        description = (
            issue.fields.description
            jira.fields.description
            if JIRA_API_VERSION == "2"
            else extract_text_from_adf(issue.raw["fields"]["description"])
        )
        comments = get_comment_strs(
            issue=issue,
            comment_email_blacklist=comment_email_blacklist,
            else extract_text_from_adf(jira.raw["fields"]["description"])
        )
        comments = _get_comment_strs(jira, comment_email_blacklist)
        ticket_content = f"{description}\n" + "\n".join(
            [f"Comment: {comment}" for comment in comments if comment]
        )
@@ -106,53 +142,66 @@ def fetch_jira_issues_batch(
        # Check ticket size
        if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
            logger.info(
                f"Skipping {issue.key} because it exceeds the maximum size of "
                f"Skipping {jira.key} because it exceeds the maximum size of "
                f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
            )
            continue

        page_url = f"{jira_client.client_info()}/browse/{issue.key}"
        page_url = f"{jira_client.client_info()}/browse/{jira.key}"

        people = set()
        try:
            creator = best_effort_get_field_from_issue(issue, "creator")
            if basic_expert_info := best_effort_basic_expert_info(creator):
                people.add(basic_expert_info)
            people.add(
                BasicExpertInfo(
                    display_name=jira.fields.creator.displayName,
                    email=jira.fields.creator.emailAddress,
                )
            )
        except Exception:
            # Author should exist but if not, doesn't matter
            pass

        try:
            assignee = best_effort_get_field_from_issue(issue, "assignee")
            if basic_expert_info := best_effort_basic_expert_info(assignee):
                people.add(basic_expert_info)
            people.add(
                BasicExpertInfo(
                    display_name=jira.fields.assignee.displayName,  # type: ignore
                    email=jira.fields.assignee.emailAddress,  # type: ignore
                )
            )
        except Exception:
            # Author should exist but if not, doesn't matter
            pass

        metadata_dict = {}
        if priority := best_effort_get_field_from_issue(issue, "priority"):
        priority = best_effort_get_field_from_issue(jira, "priority")
        if priority:
            metadata_dict["priority"] = priority.name
        if status := best_effort_get_field_from_issue(issue, "status"):
        status = best_effort_get_field_from_issue(jira, "status")
        if status:
            metadata_dict["status"] = status.name
        if resolution := best_effort_get_field_from_issue(issue, "resolution"):
        resolution = best_effort_get_field_from_issue(jira, "resolution")
        if resolution:
            metadata_dict["resolution"] = resolution.name
        if labels := best_effort_get_field_from_issue(issue, "labels"):
        labels = best_effort_get_field_from_issue(jira, "labels")
        if labels:
            metadata_dict["label"] = labels

        yield Document(
            id=page_url,
            sections=[Section(link=page_url, text=ticket_content)],
            source=DocumentSource.JIRA,
            semantic_identifier=issue.fields.summary,
            doc_updated_at=time_str_to_utc(issue.fields.updated),
            primary_owners=list(people) or None,
            # TODO add secondary_owners (commenters) if needed
            metadata=metadata_dict,
        doc_batch.append(
            Document(
                id=page_url,
                sections=[Section(link=page_url, text=ticket_content)],
                source=DocumentSource.JIRA,
                semantic_identifier=jira.fields.summary,
                doc_updated_at=time_str_to_utc(jira.fields.updated),
                primary_owners=list(people) or None,
                # TODO add secondary_owners (commenters) if needed
                metadata=metadata_dict,
            )
        )
    return doc_batch, len(batch)


class JiraConnector(LoadConnector, PollConnector, SlimConnector):
class JiraConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        jira_project_url: str,
@@ -164,8 +213,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
        labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
    ) -> None:
        self.batch_size = batch_size
        self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
        self._jira_client: JIRA | None = None
        self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
        self.jira_client: JIRA | None = None
        self._comment_email_blacklist = comment_email_blacklist or []

        self.labels_to_skip = set(labels_to_skip)
@@ -174,45 +223,54 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
    def comment_email_blacklist(self) -> tuple:
        return tuple(email.strip() for email in self._comment_email_blacklist)

    @property
    def jira_client(self) -> JIRA:
        if self._jira_client is None:
            raise ConnectorMissingCredentialError("Jira")
        return self._jira_client

    @property
    def quoted_jira_project(self) -> str:
        # Quote the project name to handle reserved words
        return f'"{self._jira_project}"'

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self._jira_client = build_jira_client(
            credentials=credentials,
            jira_base=self.jira_base,
        )
        api_token = credentials["jira_api_token"]
        # if the user provides an email, we assume it's a cloud deployment
        if "jira_user_email" in credentials:
            email = credentials["jira_user_email"]
            self.jira_client = JIRA(
                basic_auth=(email, api_token),
                server=self.jira_base,
                options={"rest_api_version": JIRA_API_VERSION},
            )
        else:
            self.jira_client = JIRA(
                token_auth=api_token,
                server=self.jira_base,
                options={"rest_api_version": JIRA_API_VERSION},
            )
        return None

    def load_from_state(self) -> GenerateDocumentsOutput:
        jql = f"project = {self.quoted_jira_project}"
        if self.jira_client is None:
            raise ConnectorMissingCredentialError("Jira")

        document_batch = []
        for doc in fetch_jira_issues_batch(
            jira_client=self.jira_client,
            jql=jql,
            batch_size=_JIRA_FULL_PAGE_SIZE,
            comment_email_blacklist=self.comment_email_blacklist,
            labels_to_skip=self.labels_to_skip,
        ):
            document_batch.append(doc)
            if len(document_batch) >= self.batch_size:
                yield document_batch
                document_batch = []
        # Quote the project name to handle reserved words
        quoted_project = f'"{self.jira_project}"'
        start_ind = 0
        while True:
            doc_batch, fetched_batch_size = fetch_jira_issues_batch(
                jql=f"project = {quoted_project}",
                start_index=start_ind,
                jira_client=self.jira_client,
                batch_size=self.batch_size,
                comment_email_blacklist=self.comment_email_blacklist,
                labels_to_skip=self.labels_to_skip,
            )

        yield document_batch
            if doc_batch:
                yield doc_batch

            start_ind += fetched_batch_size
            if fetched_batch_size < self.batch_size:
                break

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if self.jira_client is None:
            raise ConnectorMissingCredentialError("Jira")

        start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
            "%Y-%m-%d %H:%M"
        )
@@ -220,54 +278,31 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
            "%Y-%m-%d %H:%M"
        )

        # Quote the project name to handle reserved words
        quoted_project = f'"{self.jira_project}"'
        jql = (
            f"project = {self.quoted_jira_project} AND "
            f"project = {quoted_project} AND "
            f"updated >= '{start_date_str}' AND "
            f"updated <= '{end_date_str}'"
        )

        document_batch = []
        for doc in fetch_jira_issues_batch(
            jira_client=self.jira_client,
            jql=jql,
            batch_size=_JIRA_FULL_PAGE_SIZE,
            comment_email_blacklist=self.comment_email_blacklist,
            labels_to_skip=self.labels_to_skip,
        ):
            document_batch.append(doc)
            if len(document_batch) >= self.batch_size:
                yield document_batch
                document_batch = []

        yield document_batch

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateSlimDocumentOutput:
        jql = f"project = {self.quoted_jira_project}"

        slim_doc_batch = []
        for issue in _paginate_jql_search(
            jira_client=self.jira_client,
            jql=jql,
            max_results=_JIRA_SLIM_PAGE_SIZE,
            fields="key",
        ):
            issue_key = best_effort_get_field_from_issue(issue, "key")
            id = build_jira_url(self.jira_client, issue_key)
            slim_doc_batch.append(
                SlimDocument(
                    id=id,
                    perm_sync_data=None,
                )
        start_ind = 0
        while True:
            doc_batch, fetched_batch_size = fetch_jira_issues_batch(
                jql=jql,
                start_index=start_ind,
                jira_client=self.jira_client,
                batch_size=self.batch_size,
                comment_email_blacklist=self.comment_email_blacklist,
                labels_to_skip=self.labels_to_skip,
            )
            if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
                yield slim_doc_batch
                slim_doc_batch = []

        yield slim_doc_batch
            if doc_batch:
                yield doc_batch

            start_ind += fetched_batch_size
            if fetched_batch_size < self.batch_size:
                break


if __name__ == "__main__":

@@ -1,136 +1,17 @@
"""Module with custom fields processing functions"""
import os
from typing import Any
from typing import List
from urllib.parse import urlparse

from jira import JIRA
from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User

from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger

logger = setup_logger()


PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"


def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
    display_name = None
    email = None
    if hasattr(obj, "display_name"):
        display_name = obj.display_name
    else:
        display_name = obj.get("displayName")

    if hasattr(obj, "emailAddress"):
        email = obj.emailAddress
    else:
        email = obj.get("emailAddress")

    if not email and not display_name:
        return None

    return BasicExpertInfo(display_name=display_name, email=email)


def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
    if hasattr(jira_issue.fields, field):
        return getattr(jira_issue.fields, field)

    try:
        return jira_issue.raw["fields"][field]
    except Exception:
        return None


def extract_text_from_adf(adf: dict | None) -> str:
    """Extracts plain text from Atlassian Document Format:
    https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/

    WARNING: This function is incomplete and will e.g. skip lists!
    """
    texts = []
    if adf is not None and "content" in adf:
        for block in adf["content"]:
            if "content" in block:
                for item in block["content"]:
                    if item["type"] == "text":
                        texts.append(item["text"])
    return " ".join(texts)


def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
    return f"{jira_client.client_info()}/browse/{issue_key}"


def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
    api_token = credentials["jira_api_token"]
    # if the user provides an email, we assume it's a cloud deployment
    if "jira_user_email" in credentials:
        email = credentials["jira_user_email"]
        return JIRA(
            basic_auth=(email, api_token),
            server=jira_base,
            options={"rest_api_version": JIRA_API_VERSION},
        )
    else:
        return JIRA(
            token_auth=api_token,
            server=jira_base,
            options={"rest_api_version": JIRA_API_VERSION},
        )


def extract_jira_project(url: str) -> tuple[str, str]:
    parsed_url = urlparse(url)
    jira_base = parsed_url.scheme + "://" + parsed_url.netloc

    # Split the path by '/' and find the position of 'projects' to get the project name
    split_path = parsed_url.path.split("/")
    if PROJECT_URL_PAT in split_path:
        project_pos = split_path.index(PROJECT_URL_PAT)
        if len(split_path) > project_pos + 1:
            jira_project = split_path[project_pos + 1]
        else:
            raise ValueError("No project name found in the URL")
    else:
        raise ValueError("'projects' not found in the URL")

    return jira_base, jira_project


def get_comment_strs(
    issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
    comment_strs = []
    for comment in issue.fields.comment.comments:
        try:
            body_text = (
                comment.body
                if JIRA_API_VERSION == "2"
                else extract_text_from_adf(comment.raw["body"])
            )

            if (
                hasattr(comment, "author")
                and hasattr(comment.author, "emailAddress")
                and comment.author.emailAddress in comment_email_blacklist
            ):
                continue  # Skip adding comment if author's email is in blacklist

            comment_strs.append(body_text)
        except Exception as e:
            logger.error(f"Failed to process comment due to an error: {e}")
            continue

    return comment_strs


class CustomFieldExtractor:
    @staticmethod
    def _process_custom_field_value(value: Any) -> str:

@@ -2,7 +2,6 @@ import io
from datetime import datetime
from datetime import timezone

from googleapiclient.discovery import build  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore

from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -49,67 +48,6 @@ def _extract_sections_basic(
        return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]

    try:
        if mime_type == GDriveMimeType.SPREADSHEET.value:
            try:
                sheets_service = build(
                    "sheets", "v4", credentials=service._http.credentials
                )
                spreadsheet = (
                    sheets_service.spreadsheets()
                    .get(spreadsheetId=file["id"])
                    .execute()
                )

                sections = []
                for sheet in spreadsheet["sheets"]:
                    sheet_name = sheet["properties"]["title"]
                    sheet_id = sheet["properties"]["sheetId"]

                    # Get sheet dimensions
                    grid_properties = sheet["properties"].get("gridProperties", {})
                    row_count = grid_properties.get("rowCount", 1000)
                    column_count = grid_properties.get("columnCount", 26)

                    # Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
                    end_column = ""
                    while column_count:
                        column_count, remainder = divmod(column_count - 1, 26)
                        end_column = chr(65 + remainder) + end_column

                    range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"

                    try:
                        result = (
                            sheets_service.spreadsheets()
                            .values()
                            .get(spreadsheetId=file["id"], range=range_name)
                            .execute()
                        )
                        values = result.get("values", [])

                        if values:
                            text = f"Sheet: {sheet_name}\n"
                            for row in values:
                                text += "\t".join(str(cell) for cell in row) + "\n"
                            sections.append(
                                Section(
                                    link=f"{link}#gid={sheet_id}",
                                    text=text,
                                )
                            )
                    except HttpError as e:
                        logger.warning(
                            f"Error fetching data for sheet '{sheet_name}': {e}"
                        )
                        continue
                return sections

            except Exception as e:
                logger.warning(
                    f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
                    " Falling back to basic extraction."
                )
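# Standalone worked example (illustrative, not part of the diff) of the
# column-letter conversion used above: bijective base-26 with no zero digit,
# so the "- 1" shift is what makes 26 map to "Z" rather than rolling over.
#
# def _column_letter(n: int) -> str:
#     letters = ""
#     while n:
#         n, remainder = divmod(n - 1, 26)
#         letters = chr(65 + remainder) + letters
#     return letters
#
# _column_letter(1) == "A"; _column_letter(26) == "Z"
# _column_letter(27) == "AA"; _column_letter(703) == "AAA"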

        if mime_type in [
            GDriveMimeType.DOC.value,
            GDriveMimeType.PPT.value,
@@ -127,7 +65,6 @@ def _extract_sections_basic(
            .decode("utf-8")
        )
        return [Section(link=link, text=text)]

    elif mime_type in [
        GDriveMimeType.PLAIN_TEXT.value,
        GDriveMimeType.MARKDOWN.value,

@@ -197,9 +197,7 @@ class SlackbotHandler:
            return

        tokens_exist = tenant_bot_pair in self.slack_bot_tokens
        tokens_changed = (
            tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
        )
        tokens_changed = slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
        if not tokens_exist or tokens_changed:
            if tokens_exist:
                logger.info(

@@ -67,13 +67,6 @@ def create_index_attempt(
    return new_attempt.id


def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
    index_attempt = get_index_attempt(db_session, index_attempt_id)
    if index_attempt:
        db_session.delete(index_attempt)
        db_session.commit()


def mock_successful_index_attempt(
    connector_credential_pair_id: int,
    search_settings_id: int,

@@ -1181,7 +1181,7 @@ class LLMProvider(Base):
    default_model_name: Mapped[str] = mapped_column(String)
    fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)

    # Models to actually display to users
    # Models to actually disp;aly to users
    # If nulled out, we assume in the application logic we should present all
    display_model_names: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True

@@ -259,6 +259,7 @@ def get_personas(
) -> Sequence[Persona]:
    stmt = select(Persona).distinct()
    stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)

    if not include_default:
        stmt = stmt.where(Persona.builtin_persona.is_(False))
    if not include_slack_bot_personas:

@@ -10,7 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
    get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -125,7 +125,7 @@ class Chunker:
        chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
        chunk_overlap: int = CHUNK_OVERLAP,
        mini_chunk_size: int = MINI_CHUNK_SIZE,
        callback: IndexingHeartbeatInterface | None = None,
        heartbeat: Heartbeat | None = None,
    ) -> None:
        from llama_index.text_splitter import SentenceSplitter

@@ -134,7 +134,7 @@ class Chunker:
        self.enable_multipass = enable_multipass
        self.enable_large_chunks = enable_large_chunks
        self.tokenizer = tokenizer
        self.callback = callback
        self.heartbeat = heartbeat

        self.blurb_splitter = SentenceSplitter(
            tokenizer=tokenizer.tokenize,
@@ -356,14 +356,9 @@ class Chunker:
    def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
        final_chunks: list[DocAwareChunk] = []
        for document in documents:
            if self.callback:
                if self.callback.should_stop():
                    raise RuntimeError("Chunker.chunk: Stop signal detected")
            final_chunks.extend(self._handle_single_document(document))

            chunks = self._handle_single_document(document)
            final_chunks.extend(chunks)

            if self.callback:
                self.callback.progress("Chunker.chunk", len(chunks))
            if self.heartbeat:
                self.heartbeat.heartbeat()

        return final_chunks

@@ -2,7 +2,7 @@ from abc import ABC
from abc import abstractmethod

from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -34,7 +34,7 @@ class IndexingEmbedder(ABC):
        api_url: str | None,
        api_version: str | None,
        deployment_name: str | None,
        callback: IndexingHeartbeatInterface | None,
        heartbeat: Heartbeat | None,
    ):
        self.model_name = model_name
        self.normalize = normalize
@@ -60,7 +60,7 @@ class IndexingEmbedder(ABC):
            server_host=INDEXING_MODEL_SERVER_HOST,
            server_port=INDEXING_MODEL_SERVER_PORT,
            retrim_content=True,
            callback=callback,
            heartbeat=heartbeat,
        )

    @abstractmethod
@@ -83,7 +83,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        api_url: str | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        callback: IndexingHeartbeatInterface | None = None,
        heartbeat: Heartbeat | None = None,
    ):
        super().__init__(
            model_name,
@@ -95,7 +95,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            api_url,
            api_version,
            deployment_name,
            callback,
            heartbeat,
        )

    @log_function_time()
@@ -201,9 +201,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):

    @classmethod
    def from_db_search_settings(
        cls,
        search_settings: SearchSettings,
        callback: IndexingHeartbeatInterface | None = None,
        cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
    ) -> "DefaultIndexingEmbedder":
        return cls(
            model_name=search_settings.model_name,
@@ -215,5 +213,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            api_url=search_settings.api_url,
            api_version=search_settings.api_version,
            deployment_name=search_settings.deployment_name,
            callback=callback,
            heartbeat=heartbeat,
        )

@@ -1,15 +1,41 @@
from abc import ABC
from abc import abstractmethod
import abc
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed
    to run_indexing_entrypoint."""
class Heartbeat(abc.ABC):
    """Useful for any long-running work that goes through a bunch of items
    and needs to occasionally give updates on progress.
    e.g. chunking, embedding, updating vespa, etc."""

    @abstractmethod
    def should_stop(self) -> bool:
        """Signal to stop the looping function in flight."""
    @abc.abstractmethod
    def heartbeat(self, metadata: Any = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def progress(self, tag: str, amount: int) -> None:
        """Send progress updates to the caller."""


class IndexingHeartbeat(Heartbeat):
    def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
        self.cnt = 0

        self.index_attempt_id = index_attempt_id
        self.db_session = db_session
        self.freq = freq

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            index_attempt = get_index_attempt(
                db_session=self.db_session, index_attempt_id=self.index_attempt_id
            )
            if index_attempt:
                index_attempt.time_updated = func.now()
                self.db_session.commit()
            else:
                logger.error("Index attempt not found, this should not happen!")

@@ -34,7 +34,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -414,7 +414,6 @@ def build_indexing_pipeline(
    ignore_time_skip: bool = False,
    attempt_id: int | None = None,
    tenant_id: str | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
    """Builds a pipeline which takes in a list (batch) of docs and indexes them."""
    search_settings = get_current_search_settings(db_session)
@@ -441,8 +440,13 @@ def build_indexing_pipeline(
        tokenizer=embedder.embedding_model.tokenizer,
        enable_multipass=multipass,
        enable_large_chunks=enable_large_chunks,
        # after every doc, update status in case there are a bunch of really long docs
        callback=callback,
        # after every doc, update status in case there are a bunch of
        # really long docs
        heartbeat=IndexingHeartbeat(
            index_attempt_id=attempt_id, db_session=db_session, freq=1
        )
        if attempt_id
        else None,
    )

    return partial(

@@ -231,16 +231,16 @@ class QuotesProcessor:

        model_previous = self.model_output
        self.model_output += token

        if not self.found_answer_start:
            m = answer_pattern.search(self.model_output)
            if m:
                self.found_answer_start = True

                # Prevent heavy cases of hallucinations
                if self.is_json_prompt and len(self.model_output) > 400:
                    self.found_answer_end = True
                if self.is_json_prompt and len(self.model_output) > 70:
                    logger.warning("LLM did not produce json as prompted")
                    logger.debug("Model output thus far:", self.model_output)
                    self.found_answer_end = True
                    return

                remaining = self.model_output[m.end() :]

@@ -16,7 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -99,7 +99,7 @@ class EmbeddingModel:
        api_url: str | None,
        provider_type: EmbeddingProvider | None,
        retrim_content: bool = False,
        callback: IndexingHeartbeatInterface | None = None,
        heartbeat: Heartbeat | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
    ) -> None:
@@ -116,7 +116,7 @@ class EmbeddingModel:
        self.tokenizer = get_tokenizer(
            model_name=model_name, provider_type=provider_type
        )
        self.callback = callback
        self.heartbeat = heartbeat

        model_server_url = build_model_server_url(server_host, server_port)
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -160,10 +160,6 @@ class EmbeddingModel:

        embeddings: list[Embedding] = []
        for idx, text_batch in enumerate(text_batches, start=1):
            if self.callback:
                if self.callback.should_stop():
                    raise RuntimeError("_batch_encode_texts detected stop signal")

            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
            embed_request = EmbedRequest(
                model_name=self.model_name,
@@ -183,8 +179,8 @@ class EmbeddingModel:
            response = self._make_model_server_request(embed_request)
            embeddings.extend(response.embeddings)

            if self.callback:
                self.callback.progress("_batch_encode_texts", 1)
            if self.heartbeat:
                self.heartbeat.heartbeat()
        return embeddings

    def encode(

@@ -17,7 +17,7 @@ from danswer.db.document import construct_document_select_for_connector_credenti
from danswer.db.models import Document as DbDocument


class RedisConnectorDeletePayload(BaseModel):
class RedisConnectorDeletionFenceData(BaseModel):
    num_tasks: int | None
    submitted: datetime

@@ -54,18 +54,20 @@ class RedisConnectorDelete:
        return False

    @property
    def payload(self) -> RedisConnectorDeletePayload | None:
    def payload(self) -> RedisConnectorDeletionFenceData | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorDeletePayload.model_validate_json(cast(str, fence_str))
        payload = RedisConnectorDeletionFenceData.model_validate_json(
            cast(str, fence_str)
        )

        return payload

    def set_fence(self, payload: RedisConnectorDeletePayload | None) -> None:
    def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
        if not payload:
            self.redis.delete(self.fence_key)
            return

@@ -30,6 +30,7 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle

logger = setup_logger()


admin_router = APIRouter(prefix="/admin/llm")
basic_router = APIRouter(prefix="/llm")


@@ -4,10 +4,6 @@ import re
import string
from urllib.parse import quote

from danswer.utils.logger import setup_logger


logger = setup_logger(__name__)

ESCAPE_SEQUENCE_RE = re.compile(
    r"""
@@ -81,8 +77,7 @@ def extract_embedded_json(s: str) -> dict:
    last_brace_index = s.rfind("}")

    if first_brace_index == -1 or last_brace_index == -1:
        logger.warning("No valid json found, assuming answer is entire string")
        return {"answer": s, "quotes": []}
        raise ValueError("No valid json found")

    json_str = s[first_brace_index : last_brace_index + 1]
    try:

@@ -411,8 +411,6 @@ def _validate_curator_status__no_commit(
        .all()
    )

    # if the user is a curator in any of their groups, set their role to CURATOR
    # otherwise, set their role to BASIC
    if curator_relationships:
        user.role = UserRole.CURATOR
    elif user.role == UserRole.CURATOR:
@@ -438,15 +436,6 @@ def update_user_curator_relationship(
    user = fetch_user_by_id(db_session, set_curator_request.user_id)
    if not user:
        raise ValueError(f"User with id '{set_curator_request.user_id}' not found")

    if user.role == UserRole.ADMIN:
        raise ValueError(
            f"User '{user.email}' is an admin and therefore has all permissions "
            "of a curator. If you'd like this user to only have curator permissions, "
            "you must update their role to BASIC then assign them to be CURATOR in the "
            "appropriate groups."
        )

    requested_user_groups = fetch_user_groups_for_user(
        db_session=db_session,
        user_id=set_curator_request.user_id,

@@ -1,5 +1,7 @@
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from typing import Any

from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
@@ -9,30 +11,26 @@ from ee.danswer.db.external_perm import ExternalUserGroup
logger = setup_logger()


def _build_group_member_email_map(
def _get_group_members_email_paginated(
    confluence_client: OnyxConfluence,
) -> dict[str, set[str]]:
    group_member_emails: dict[str, set[str]] = {}
    for user_result in confluence_client.paginated_cql_user_retrieval():
        user = user_result["user"]
        email = user.get("email")
    group_name: str,
) -> set[str]:
    members: list[dict[str, Any]] = []
    for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
        members.extend(member_batch)

    group_member_emails: set[str] = set()
    for member in members:
        email = member.get("email")
        if not email:
            # This field is only present in Confluence Server
            user_name = user.get("username")
            # If it is present, try to get the email using a Server-specific method
            user_name = member.get("username")
            if user_name:
                email = get_user_email_from_username__server(
                    confluence_client=confluence_client,
                    user_name=user_name,
                )
        if not email:
            # If we still don't have an email, skip this user
            continue

        for group in confluence_client.paginated_groups_by_user_retrieval(user):
            # group name uniqueness is enforced by Confluence, so we can use it as a group ID
            group_id = group["name"]
            group_member_emails.setdefault(group_id, set()).add(email)
        if email:
            group_member_emails.add(email)

    return group_member_emails
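# Illustrative usage sketch (not part of the diff): gather the member emails
# for one Confluence group; the group name here is hypothetical.
#
# emails = _get_group_members_email_paginated(confluence_client, "engineering")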
|
||||
|
||||
@@ -40,20 +38,31 @@ def _build_group_member_email_map(
|
||||
def confluence_group_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
confluence_client = build_confluence_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
|
||||
credentials_json=cc_pair.credential.credential_json,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
)
|
||||
# Get all group names
|
||||
group_names: list[str] = []
|
||||
for group_batch in confluence_client.paginated_groups_retrieval():
|
||||
for group in group_batch:
|
||||
if group_name := group.get("name"):
|
||||
group_names.append(group_name)
|
||||
|
||||
# For each group name, get all members and create a danswer group
|
||||
danswer_groups: list[ExternalUserGroup] = []
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
for group_name in group_names:
|
||||
group_member_emails = _get_group_members_email_paginated(
|
||||
confluence_client, group_name
|
||||
)
|
||||
if not group_member_emails:
|
||||
continue
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_id,
|
||||
id=group_name,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -55,12 +55,7 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
}
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 60 seconds
|
||||
DocumentSource.GOOGLE_DRIVE: 60,
|
||||
DocumentSource.CONFLUENCE: 60,
|
||||
}
|
||||
EXTERNAL_GROUP_SYNC_PERIOD: int = 30 # 30 seconds
|
||||
|
||||
|
||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_connector_creation(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connectors
|
||||
cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
assert cc_pair_info
|
||||
assert cc_pair_info.creator
|
||||
assert str(cc_pair_info.creator) == admin_user.id
|
||||
assert cc_pair_info.creator_email == admin_user.email
|
||||
|
||||
|
||||
def test_overlapping_connector_creation(reset: None) -> None:
|
||||
"""Tests that connectors indexing the same documents don't interfere with each other.
|
||||
A previous bug involved document by cc pair entries not being added for new connectors
|
||||
when the docs existed already via another connector and were up to date relative to the source.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
config = {
|
||||
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
"space": "DailyConne",
|
||||
"is_cloud": True,
|
||||
"page_id": "",
|
||||
}
|
||||
|
||||
credential = {
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
}
|
||||
|
||||
# store the time before we create the connector so that we know after
|
||||
# when the indexing should have started
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# create connector
|
||||
cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
connector_specific_config=config,
|
||||
credential_json=credential,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
CCPairManager.wait_for_indexing(
|
||||
cc_pair_1, now, timeout=120, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
cc_pair_2 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
connector_specific_config=config,
|
||||
credential_json=credential,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
CCPairManager.wait_for_indexing(
|
||||
cc_pair_2, now, timeout=120, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
info_1 = CCPairManager.get_single(cc_pair_1.id, user_performing_action=admin_user)
|
||||
assert info_1
|
||||
|
||||
info_2 = CCPairManager.get_single(cc_pair_2.id, user_performing_action=admin_user)
|
||||
assert info_2
|
||||
|
||||
assert info_1.num_docs_indexed == info_2.num_docs_indexed
|
||||
@@ -29,6 +29,78 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def test_connector_creation(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connectors
|
||||
cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
assert cc_pair_info
|
||||
assert cc_pair_info.creator
|
||||
assert str(cc_pair_info.creator) == admin_user.id
|
||||
assert cc_pair_info.creator_email == admin_user.email
|
||||
|
||||
|
||||
# TODO(rkuo): will enable this once i have credentials on github
|
||||
# def test_overlapping_connector_creation(reset: None) -> None:
|
||||
# # Creating an admin user (first user created is automatically an admin)
|
||||
# admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# config = {
|
||||
# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
# "space": os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
# "is_cloud": True,
|
||||
# "page_id": "",
|
||||
# }
|
||||
|
||||
# credential = {
|
||||
# "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
# }
|
||||
|
||||
# # store the time before we create the connector so that we know after
|
||||
# # when the indexing should have started
|
||||
# now = datetime.now(timezone.utc)
|
||||
|
||||
# # create connector
|
||||
# cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# connector_specific_config=config,
|
||||
# credential_json=credential,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
|
||||
# CCPairManager.wait_for_indexing(
|
||||
# cc_pair_1, now, timeout=60, user_performing_action=admin_user
|
||||
# )
|
||||
|
||||
# cc_pair_2 = CCPairManager.create_from_scratch(
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# connector_specific_config=config,
|
||||
# credential_json=credential,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
|
||||
# CCPairManager.wait_for_indexing(
|
||||
# cc_pair_2, now, timeout=60, user_performing_action=admin_user
|
||||
# )
|
||||
|
||||
# info_1 = CCPairManager.get_single(cc_pair_1.id)
|
||||
# assert info_1
|
||||
|
||||
# info_2 = CCPairManager.get_single(cc_pair_2.id)
|
||||
# assert info_2
|
||||
|
||||
# assert info_1.num_docs_indexed == info_2.num_docs_indexed
|
||||
|
||||
|
||||
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
@@ -17,48 +18,49 @@ def mock_jira_client() -> MagicMock:

@pytest.fixture
def mock_issue_small() -> MagicMock:
    issue = MagicMock(spec=Issue)
    fields = MagicMock()
    fields.description = "Small description"
    fields.comment = MagicMock()
    fields.comment.comments = [
    issue = MagicMock()
    issue.key = "SMALL-1"
    issue.fields.description = "Small description"
    issue.fields.comment.comments = [
        MagicMock(body="Small comment 1"),
        MagicMock(body="Small comment 2"),
    ]
    fields.creator = MagicMock()
    fields.creator.displayName = "John Doe"
    fields.creator.emailAddress = "john@example.com"
    fields.summary = "Small Issue"
    fields.updated = "2023-01-01T00:00:00+0000"
    fields.labels = []

    issue.fields = fields
    issue.key = "SMALL-1"
    issue.fields.creator.displayName = "John Doe"
    issue.fields.creator.emailAddress = "john@example.com"
    issue.fields.summary = "Small Issue"
    issue.fields.updated = "2023-01-01T00:00:00+0000"
    issue.fields.labels = []
    return issue


@pytest.fixture
def mock_issue_large() -> MagicMock:
    issue = MagicMock(spec=Issue)
    fields = MagicMock()
    fields.description = "a" * 99_000
    fields.comment = MagicMock()
    fields.comment.comments = [
    # This will be larger than 100KB
    issue = MagicMock()
    issue.key = "LARGE-1"
    issue.fields.description = "a" * 99_000
    issue.fields.comment.comments = [
        MagicMock(body="Large comment " * 1000),
        MagicMock(body="Another large comment " * 1000),
    ]
    fields.creator = MagicMock()
    fields.creator.displayName = "Jane Doe"
    fields.creator.emailAddress = "jane@example.com"
    fields.summary = "Large Issue"
    fields.updated = "2023-01-02T00:00:00+0000"
    fields.labels = []

    issue.fields = fields
    issue.key = "LARGE-1"
    issue.fields.creator.displayName = "Jane Doe"
    issue.fields.creator.emailAddress = "jane@example.com"
    issue.fields.summary = "Large Issue"
    issue.fields.updated = "2023-01-02T00:00:00+0000"
    issue.fields.labels = []
    return issue


@pytest.fixture
def patched_type() -> Callable[[Any], type]:
    def _patched_type(obj: Any) -> type:
        if isinstance(obj, MagicMock):
            return Issue
        return type(obj)

    return _patched_type


@pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]:
    with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
@@ -67,9 +69,11 @@ def mock_jira_api_version() -> Generator[Any, Any, Any]:

@pytest.fixture
def patched_environment(
    patched_type: type,
    mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]:
    yield
    with patch("danswer.connectors.danswer_jira.connector.type", patched_type):
        yield


def test_fetch_jira_issues_batch_small_ticket(
@@ -79,8 +83,9 @@ def test_fetch_jira_issues_batch_small_ticket(
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small]

    docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 1
    assert len(docs) == 1
    assert docs[0].id.endswith("/SMALL-1")
    assert "Small description" in docs[0].sections[0].text
@@ -95,8 +100,9 @@ def test_fetch_jira_issues_batch_large_ticket(
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_large]

    docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 1
    assert len(docs) == 0  # The large ticket should be skipped


@@ -108,8 +114,9 @@ def test_fetch_jira_issues_batch_mixed_tickets(
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]

    docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 2
    assert len(docs) == 1  # Only the small ticket should be included
    assert docs[0].id.endswith("/SMALL-1")

@@ -123,6 +130,7 @@ def test_fetch_jira_issues_batch_custom_size_limit(
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]

    docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 2
    assert len(docs) == 0  # Both tickets should be skipped due to the low size limit
@@ -1,16 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.indexing.indexing_heartbeat import Heartbeat
|
||||
|
||||
|
||||
class MockHeartbeat(IndexingHeartbeatInterface):
|
||||
class MockHeartbeat(Heartbeat):
|
||||
def __init__(self) -> None:
|
||||
self.call_count = 0
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
def heartbeat(self, metadata: Any = None) -> None:
|
||||
self.call_count += 1
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ def test_chunker_heartbeat(
|
||||
chunker = Chunker(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=False,
|
||||
callback=mock_heartbeat,
|
||||
heartbeat=mock_heartbeat,
|
||||
)
|
||||
|
||||
chunks = chunker.chunk([document])
|
||||
|
||||
80
backend/tests/unit/danswer/indexing/test_heartbeat.py
Normal file
@@ -0,0 +1,80 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from sqlalchemy.orm import Session

from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat


@pytest.fixture
def mock_db_session() -> MagicMock:
    return MagicMock(spec=Session)


@pytest.fixture
def mock_index_attempt() -> MagicMock:
    return MagicMock(spec=IndexAttempt)


def test_indexing_heartbeat(
    mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt:
        mock_get_index_attempt.return_value = mock_index_attempt

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=5
        )

        # Test that heartbeat doesn't update before freq is reached
        for _ in range(4):
            heartbeat.heartbeat()

        mock_db_session.commit.assert_not_called()

        # Test that heartbeat updates when freq is reached
        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        assert mock_index_attempt.time_updated is not None
        mock_db_session.commit.assert_called_once()

        # Reset mock calls
        mock_db_session.reset_mock()
        mock_get_index_attempt.reset_mock()

        # Test that heartbeat updates again after freq more calls
        for _ in range(5):
            heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once()
        mock_db_session.commit.assert_called_once()


def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt, patch(
        "danswer.indexing.indexing_heartbeat.logger"
    ) as mock_logger:
        mock_get_index_attempt.return_value = None

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=1
        )

        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        mock_logger.error.assert_called_once_with(
            "Index attempt not found, this should not happen!"
        )
        mock_db_session.commit.assert_not_called()
@@ -324,13 +324,8 @@ def test_lengthy_prefixed_json_with_quotes() -> None:
    assert quotes[0] == "Document"


def test_json_with_lengthy_prefix_and_quotes() -> None:
def test_prefixed_json_with_quotes() -> None:
    tokens = [
        "*** Based on the provided documents, there does not appear to be any information ",
        "directly relevant to answering which documents are my favorite. ",
        "The documents seem to be focused on describing the Danswer product ",
        "and its features/use cases. Since I do not have personal preferences ",
        "for documents, I will provide a general response:\n\n",
        "```",
        "json",
        "\n",

@@ -5,7 +5,7 @@
For general information, please read the instructions in this [README](https://github.com/danswer-ai/danswer/blob/main/deployment/README.md).

## Deploy in a system without GPU support
This part is elaborated precisely in this [README](https://github.com/danswer-ai/danswer/blob/main/deployment/README.md) in section *Docker Compose*. If you have any questions, please feel free to open an issue or get in touch in slack for support.
This part is elaborated precisely in in this [README](https://github.com/danswer-ai/danswer/blob/main/deployment/README.md) in section *Docker Compose*. If you have any questions, please feel free to open an issue or get in touch in slack for support.

## Deploy in a system with GPU support
Running Model servers with GPU support while indexing and querying can result in significant improvements in performance. This is highly recommended if you have access to resources. Currently, Danswer offloads embedding model and tokenizers to the GPU VRAM and the size needed depends on chosen embedding model. For example, the embedding model `nomic-ai/nomic-embed-text-v1` takes up about 1GB of VRAM. That means running this model for inference and embedding pipeline would require roughly 2GB of VRAM.
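To actually hand a GPU to the model servers in the Docker Compose deployment, the Compose spec's device reservations can be used. The following is a minimal sketch, not part of this commit: it assumes the NVIDIA Container Toolkit is installed on the host and that the service is named `inference_model_server` as in `deployment/docker_compose`; verify the exact service names in your compose file.

    # Hypothetical override file, e.g. docker-compose.gpu.override.yml
    # Service name assumed from deployment/docker_compose; adjust to your setup.
    services:
      inference_model_server:
        deploy:
          resources:
            reservations:
              devices:
                - driver: nvidia   # requires the NVIDIA Container Toolkit on the host
                  count: 1
                  capabilities: [gpu]

The same stanza can be repeated for the indexing model server; with the ~1GB VRAM figure quoted above, roughly 2GB of free VRAM should cover both the inference and indexing pipelines.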
4
web/.gitignore
vendored
@@ -34,7 +34,3 @@ yarn-error.log*
# typescript
*.tsbuildinfo
next-env.d.ts

/admin_auth.json
/build-archive.log

@@ -69,9 +69,6 @@ ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}

RUN npx next build

# Step 2. Production image, copy all the files and run next
@@ -137,12 +134,9 @@ ARG NEXT_PUBLIC_POSTHOG_KEY
ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}

ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}

# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to

@@ -1,8 +1,8 @@
import { defineConfig, devices } from "@playwright/test";
import { defineConfig } from "@playwright/test";

export default defineConfig({
  // Other Playwright config options
  testDir: "./tests/e2e", // Folder for test files
  reporter: "list",
  // Configure paths for screenshots
  // expect: {
  //   toMatchSnapshot: {
@@ -11,30 +11,4 @@ export default defineConfig({
  // },
  // reporter: [["html", { outputFolder: "test-results/output/report" }]], // HTML report location
  // outputDir: "test-results/output/screenshots", // Set output folder for test artifacts
  projects: [
    {
      // dependency for admin workflows
      name: "admin_setup",
      testMatch: /.*\admin_auth.setup\.ts/,
    },
    {
      // tests admin workflows
      name: "chromium-admin",
      grep: /@admin/,
      use: {
        ...devices["Desktop Chrome"],
        // Use prepared auth state.
        storageState: "admin_auth.json",
      },
      dependencies: ["admin_setup"],
    },
    {
      // tests logged out / guest workflows
      name: "chromium-guest",
      grep: /@guest/,
      use: {
        ...devices["Desktop Chrome"],
      },
    },
  ],
});

@@ -29,7 +29,9 @@ import { deleteApiKey, regenerateApiKey } from "./lib";
import { DanswerApiKeyForm } from "./DanswerApiKeyForm";
import { APIKey } from "./types";

const API_KEY_TEXT = `API Keys allow you to access Danswer APIs programmatically. Click the button below to generate a new API Key.`;
const API_KEY_TEXT = `
API Keys allow you to access Danswer APIs programmatically. Click the button below to generate a new API Key.
`;

function NewApiKeyModal({
  apiKey,

@@ -25,7 +25,6 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
      <CardSection>
        <AssistantEditor
          {...values}
          admin
          defaultPublic={true}
          redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
        />

@@ -51,7 +51,7 @@ export const SlackTokensForm = ({
          : "Successfully created Slack Bot!",
        type: "success",
      });
      router.push(`/admin/bots/${encodeURIComponent(botId)}`);
      router.push(`/admin/bots/${botId}}`);
    } else {
      const responseJson = await response.json();
      const errorMsg = responseJson.detail || responseJson.message;

@@ -142,8 +142,6 @@ export function CustomLLMProviderUpdateForm({
      },
      body: JSON.stringify({
        ...values,
        // For custom llm providers, all model names are displayed
        display_model_names: values.model_names,
        custom_config: customConfigProcessing(values.custom_config_list),
      }),
    });

@@ -278,6 +278,7 @@ export function LLMProviderUpdateForm({
      {!(hideAdvanced && llmProviderDescriptor.name != "azure") && (
        <>
          <Separator />

          {llmProviderDescriptor.llm_names.length > 0 ? (
            <SelectorFormField
              name="default_model_name"
@@ -297,6 +298,7 @@ export function LLMProviderUpdateForm({
              placeholder="E.g. gpt-4"
            />
          )}

          {llmProviderDescriptor.deployment_name_required && (
            <TextFormField
              small={hideAdvanced}
@@ -305,6 +307,7 @@ export function LLMProviderUpdateForm({
              placeholder="Deployment Name"
            />
          )}

          {!llmProviderDescriptor.single_model_supported &&
            (llmProviderDescriptor.llm_names.length > 0 ? (
              <SelectorFormField
@@ -341,6 +344,7 @@ export function LLMProviderUpdateForm({
          />
        </>
      )}

      {showAdvancedOptions && (
        <>
          {llmProviderDescriptor.llm_names.length > 0 && (

@@ -28,6 +28,7 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
      <CardSection>
        <AssistantEditor
          {...values}
          admin
          defaultPublic={false}
          redirectType={SuccessfulPersonaUpdateRedirectType.CHAT}
        />

@@ -52,7 +52,6 @@ import {
  useLayoutEffect,
  useRef,
  useState,
  useMemo,
} from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
@@ -267,6 +266,7 @@ export function ChatPage({
    availableAssistants[0];

  const noAssistants = liveAssistant == null || liveAssistant == undefined;

  // always set the model override for the chat session, when an assistant, llm provider, or user preference exists
  useEffect(() => {
    const personaDefault = getLLMProviderOverrideForPersona(
@@ -282,7 +282,7 @@ export function ChatPage({
      );
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [liveAssistant, user?.preferences.default_model]);
  }, [liveAssistant, llmProviders, user?.preferences.default_model]);

  const stopGenerating = () => {
    const currentSession = currentSessionId();
@@ -2007,7 +2007,7 @@ export function ChatPage({
        {...getRootProps()}
      >
        <div
          className={`w-full h-full flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
          className={`w-full h-full flex flex-col overflow-y-auto include-scrollbar overflow-x-hidden relative`}
          ref={scrollableDivRef}
        >
          {/* ChatBanner is a custom banner that displays a admin-specified message at

@@ -23,22 +23,16 @@ import { useRouter } from "next/navigation";
import { CHAT_SESSION_ID_KEY } from "@/lib/drag/constants";
import Cookies from "js-cookie";
import { Popover } from "@/components/popover/Popover";
import { ChatSession } from "../interfaces";
const FolderItem = ({
  folder,
  currentChatId,
  isInitiallyExpanded,
  initiallySelected,
  showShareModal,
  showDeleteModal,
}: {
  folder: Folder;
  currentChatId?: string;
  isInitiallyExpanded: boolean;
  initiallySelected: boolean;

  showShareModal: ((chatSession: ChatSession) => void) | undefined;
  showDeleteModal: ((chatSession: ChatSession) => void) | undefined;
}) => {
  const [isExpanded, setIsExpanded] = useState<boolean>(isInitiallyExpanded);
  const [isEditing, setIsEditing] = useState<boolean>(initiallySelected);
@@ -167,9 +161,6 @@ const FolderItem = ({
    return a.time_created.localeCompare(b.time_created);
  });

  // Determine whether to show the trash can icon
  const showTrashIcon = (isHovering && !isEditing) || showDeleteConfirm;

  return (
    <div
      key={folder.folder_id}
@@ -215,60 +206,55 @@ const FolderItem = ({
            {editedFolderName || folder.folder_name}
          </div>
        )}

        <div className="flex ml-auto my-auto">
          <div
            onClick={handleEditFolderName}
            className={`hover:bg-black/10 p-1 -m-1 rounded ${
              isHovering && !isEditing
                ? ""
                : "opacity-0 pointer-events-none"
            }`}
          >
            <FiEdit2 size={16} />
          </div>

          <div className="relative">
            <Popover
              open={showDeleteConfirm}
              onOpenChange={setShowDeleteConfirm}
              content={
                <div
                  onClick={handleDeleteClick}
                  className={`hover:bg-black/10 p-1 -m-1 rounded ml-2 ${
                    showTrashIcon ? "" : "opacity-0 pointer-events-none"
                  }`}
                >
                  <FiTrash size={16} />
                </div>
              }
              popover={
                <div className="p-2 w-[225px] bg-background-100 rounded shadow-lg">
                  <p className="text-sm mb-2">
                    Are you sure you want to delete folder{" "}
                    <i>{folder.folder_name}</i>?
                  </p>
                  <div className="flex justify-end">
                    <button
                      onClick={confirmDelete}
                      className="bg-red-500 hover:bg-red-600 text-white px-2 py-1 rounded text-xs mr-2"
                    >
                      Yes
                    </button>
                    <button
                      onClick={cancelDelete}
                      className="bg-gray-300 hover:bg-gray-200 px-2 py-1 rounded text-xs"
                    >
                      No
                    </button>
        {isHovering && !isEditing && (
          <div className="flex ml-auto my-auto">
            <div
              onClick={handleEditFolderName}
              className="hover:bg-black/10 p-1 -m-1 rounded"
            >
              <FiEdit2 size={16} />
            </div>
            <div className="relative">
              <Popover
                open={showDeleteConfirm}
                onOpenChange={setShowDeleteConfirm}
                content={
                  <div
                    onClick={handleDeleteClick}
                    className="hover:bg-black/10 p-1 -m-1 rounded ml-2"
                  >
                    <FiTrash size={16} />
                  </div>
                }
                side="top"
                align="center"
              />
              }
              popover={
                <div className="p-2 w-[225px] bg-background-100 rounded shadow-lg">
                  <p className="text-sm mb-2">
                    Are you sure you want to delete{" "}
                    <i>{folder.folder_name}</i>? All the content inside
                    this folder will also be deleted.
                  </p>
                  <div className="flex justify-end">
                    <button
                      onClick={confirmDelete}
                      className="bg-red-500 hover:bg-red-600 text-white px-2 py-1 rounded text-xs mr-2"
                    >
                      Yes
                    </button>
                    <button
                      onClick={cancelDelete}
                      className="bg-gray-300 hover:bg-gray-200 px-2 py-1 rounded text-xs"
                    >
                      No
                    </button>
                  </div>
                </div>
              }
              side="top"
              align="center"
            />
          </div>
        </div>
      </div>
    )}

        {isEditing && (
          <div className="flex ml-auto my-auto">
@@ -290,8 +276,6 @@ const FolderItem = ({
          </div>
        </div>
      </BasicSelectable>

      {/* Expanded Folder Content */}
      {isExpanded && folders && (
        <div className={"ml-2 pl-2 border-l border-border"}>
          {folders.map((chatSession) => (
@@ -300,8 +284,6 @@ const FolderItem = ({
              chatSession={chatSession}
              isSelected={chatSession.id === currentChatId}
              skipGradient={isDragOver}
              showShareModal={showShareModal}
              showDeleteModal={showDeleteModal}
            />
          ))}
        </div>
@@ -315,15 +297,11 @@ export const FolderList = ({
  currentChatId,
  openedFolders,
  newFolderId,
  showShareModal,
  showDeleteModal,
}: {
  folders: Folder[];
  currentChatId?: string;
  openedFolders?: { [key: number]: boolean };
  newFolderId: number | null;
  showShareModal: ((chatSession: ChatSession) => void) | undefined;
  showDeleteModal: ((chatSession: ChatSession) => void) | undefined;
}) => {
  if (folders.length === 0) {
    return null;
@@ -340,8 +318,6 @@ export const FolderList = ({
          isInitiallyExpanded={
            openedFolders ? openedFolders[folder.folder_id] || false : false
          }
          showShareModal={showShareModal}
          showDeleteModal={showDeleteModal}
        />
      ))}
      {folders.length == 1 && folders[0].chat_sessions.length == 0 && (

@@ -76,7 +76,7 @@ export function AssistantsTab({
          items={assistants.map((a) => a.id.toString())}
          strategy={verticalListSortingStrategy}
        >
          <div className="px-4 pb-2 max-h-[500px] default-scrollbar overflow-y-scroll overflow-x-hidden my-3 grid grid-cols-1 gap-4">
          <div className="px-4 pb-2 max-h-[500px] include-scrollbar overflow-y-scroll my-3 grid grid-cols-1 gap-4">
            {assistants.map((assistant) => (
              <DraggableAssistantCard
                key={assistant.id.toString()}

@@ -191,6 +191,7 @@ export function ChatSessionDisplay({
          </div>
        </CustomTooltip>
      )}

      <div>
        {search ? (
          showDeleteModal && (

@@ -74,8 +74,6 @@ export function PagesTab({
            folders={folders}
            currentChatId={currentChatId}
            openedFolders={openedFolders}
            showShareModal={showShareModal}
            showDeleteModal={showDeleteModal}
          />
        </div>
      )}

@@ -260,29 +260,26 @@
  }
}

.default-scrollbar::-webkit-scrollbar {
.include-scrollbar::-webkit-scrollbar {
  width: 6px;
}

.default-scrollbar::-webkit-scrollbar-track {
.include-scrollbar::-webkit-scrollbar-track {
  background: #f1f1f1;
}

.default-scrollbar::-webkit-scrollbar-thumb {
.include-scrollbar::-webkit-scrollbar-thumb {
  background: #888;
  border-radius: 4px;
}

.default-scrollbar::-webkit-scrollbar-thumb:hover {
.include-scrollbar::-webkit-scrollbar-thumb:hover {
  background: #555;
}

.default-scrollbar {
.include-scrollbar {
  scrollbar-width: thin;
  scrollbar-color: #888 transparent;
  overflow: overlay;
  overflow-y: scroll;
  overflow-x: hidden;
}

.inputscroll::-webkit-scrollbar-track {

@@ -6,7 +6,6 @@ import {
} from "@/components/settings/lib";
import {
  CUSTOM_ANALYTICS_ENABLED,
  GTM_ENABLED,
  SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
} from "@/lib/constants";
import { Metadata } from "next";
@@ -22,7 +21,6 @@ import { getCurrentUserSS } from "@/lib/userSS";
import CardSection from "@/components/admin/CardSection";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";

const inter = Inter({
  subsets: ["latin"],
@@ -82,22 +80,6 @@ export default async function RootLayout({
            }}
          />
        )}

        {GTM_ENABLED && (
          <Script
            id="google-tag-manager"
            strategy="afterInteractive"
            dangerouslySetInnerHTML={{
              __html: `
                (function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
                new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
                j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
                'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
                })(window,document,'script','dataLayer','GTM-PZXS36NG');
              `,
            }}
          />
        )}
      </head>
      <body className={`relative ${inter.variable} font-sans`}>
        <div

@@ -20,7 +20,7 @@ export function AdvancedOptionsToggle({
        size="sm"
        icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
        onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
        className="text-xs mr-auto !p-0 text-text-950 hover:text-text-500"
        className="text-xs !p-0 text-text-950 hover:text-text-500"
      >
        {title || "Advanced Options"}
      </Button>

@@ -57,7 +57,7 @@ const DropdownOption: React.FC<DropdownOptionProps> = ({
};

export function UserDropdown({ page }: { page?: pageType }) {
  const { user, isCurator } = useUser();
  const { user } = useUser();
  const [userInfoVisible, setUserInfoVisible] = useState(false);
  const userInfoRef = useRef<HTMLDivElement>(null);
  const router = useRouter();
@@ -95,9 +95,7 @@ export function UserDropdown({ page }: { page?: pageType }) {
  }

  // Construct the current URL
  const currentUrl = `${pathname}${
    searchParams.toString() ? `?${searchParams.toString()}` : ""
  }`;
  const currentUrl = `${pathname}${searchParams.toString() ? `?${searchParams.toString()}` : ""}`;

  // Encode the current URL to use as a redirect parameter
  const encodedRedirect = encodeURIComponent(currentUrl);
@@ -108,7 +106,9 @@ export function UserDropdown({ page }: { page?: pageType }) {
  };

  const showAdminPanel = !user || user.role === UserRole.ADMIN;
  const showCuratorPanel = user && isCurator;
  const showCuratorPanel =
    user &&
    (user.role === UserRole.CURATOR || user.role === UserRole.GLOBAL_CURATOR);
  const showLogout =
    user && !checkUserIsNoAuthUser(user.id) && !LOGOUT_DISABLED;

@@ -244,11 +244,7 @@ export function UserDropdown({ page }: { page?: pageType }) {
                    setShowNotifications(true);
                  }}
                  icon={<BellIcon className="h-5 w-5 my-auto mr-2" />}
                  label={`Notifications ${
                    notifications && notifications.length > 0
                      ? `(${notifications.length})`
                      : ""
                  }`}
                  label={`Notifications ${notifications && notifications.length > 0 ? `(${notifications.length})` : ""}`}
                />

                {showLogout &&

@@ -211,7 +211,12 @@ export function TextFormField({
    <div className={`w-full ${width}`}>
      <div className="flex gap-x-2 items-center">
        {!removeLabel && (
          <Label className={sizeClass.label} small={small}>
          <Label
            className={`${
              small ? "text-text-950" : "text-text-700 font-normal"
            } ${sizeClass.label}`}
            small={small}
          >
            {label}
          </Label>
        )}
@@ -657,10 +662,7 @@ export function SelectorFormField({
      {container && (
        <SelectContent
          side={side}
          className={`
            ${maxHeight ? `${maxHeight}` : "max-h-72"}
            overflow-y-scroll
          `}
          className={maxHeight ? `max-h-[${maxHeight}]` : undefined}
          container={container}
        >
          {options.length === 0 ? (

@@ -47,7 +47,7 @@ export const AssistantsProvider: React.FC<{
  const [assistants, setAssistants] = useState<Persona[]>(
    initialAssistants || []
  );
  const { user, isLoadingUser, isAdmin, isCurator } = useUser();
  const { user, isLoadingUser, isAdmin } = useUser();
  const [editablePersonas, setEditablePersonas] = useState<Persona[]>([]);
  const [allAssistants, setAllAssistants] = useState<Persona[]>([]);

@@ -83,7 +83,7 @@ export const AssistantsProvider: React.FC<{

  useEffect(() => {
    const fetchPersonas = async () => {
      if (!isAdmin && !isCurator) {
      if (!isAdmin) {
        return;
      }

@@ -101,8 +101,6 @@ export const AssistantsProvider: React.FC<{
        if (allResponse.ok) {
          const allPersonas = await allResponse.json();
          setAllAssistants(allPersonas);
        } else {
          console.error("Error fetching personas:", allResponse);
        }
      } catch (error) {
        console.error("Error fetching personas:", error);
@@ -110,7 +108,7 @@ export const AssistantsProvider: React.FC<{
    };

    fetchPersonas();
  }, [isAdmin, isCurator]);
  }, [isAdmin]);

  const refreshRecentAssistants = async (currentAssistant: number) => {
    const response = await fetch("/api/user/recent-assistants", {

@@ -62,11 +62,7 @@ export const LlmList: React.FC<LlmListProps> = ({

  return (
    <div
      className={`${
        scrollable
          ? "max-h-[200px] default-scrollbar overflow-x-hidden"
          : "max-h-[300px]"
      } bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
      className={`${scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
    >
      {userDefault && (
        <button

@@ -18,7 +18,7 @@ export default function ExceptionTraceModal({
      title="Full Exception Trace"
      onOutsideClick={onOutsideClick}
    >
      <div className="overflow-y-auto default-scrollbar overflow-x-hidden pr-3 h-full mb-6">
      <div className="overflow-y-auto include-scrollbar pr-3 h-full mb-6">
        <div className="mb-6">
          {!copyClicked ? (
            <div

@@ -169,8 +169,10 @@ export const SearchResultsDisplay = ({

      {documents && documents.length > 0 && (
        <div className="mt-4">
          <div className="font-bold flex h-12 justify-between text-emphasis border-b mb-3 pb-1 border-border text-lg">
          <div className="font-bold flex justify-between text-emphasis border-b mb-3 pb-1 border-border text-lg">
            <p>Results</p>
            <div className="h-8 w-0 invisible">invisibel text</div>

            {!DISABLE_LLM_DOC_RELEVANCE &&
              (contentEnriched || searchResponse.additional_relevance) && (
                <TooltipProvider delayDuration={1000}>

@@ -698,7 +698,7 @@ export const SearchSection = ({
        </div>
      </div>

      <div className="absolute default-scrollbar h-screen overflow-y-auto overflow-x-hidden left-0 w-full top-0">
      <div className="absolute left-0 w-full top-0">
        <FunctionalHeader
          sidebarToggled={toggledSidebar}
          reset={() => setQuery("")}
@@ -728,6 +728,7 @@ export const SearchSection = ({
          } pt-10 relative max-w-[2000px] xl:max-w-[1430px] mx-auto`}
        >
          <div className="absolute z-10 mobile:px-4 mobile:max-w-searchbar-max mobile:w-[90%] top-12 desktop:left-4 hidden 2xl:block mobile:left-1/2 mobile:transform mobile:-translate-x-1/2 desktop:w-52 3xl:w-64">
            {/* Remove this entire SourceSelector block
            {!settings?.isMobile &&
              (ccPairs.length > 0 || documentSets.length > 0) && (
                <SourceSelector
@@ -738,6 +739,7 @@ export const SearchSection = ({
                  availableTags={tags}
                />
              )}
            */}
          </div>
          <div className="absolute left-0 hidden 2xl:block w-52 3xl:w-64"></div>
          <div className="max-w-searchbar-max w-[90%] mx-auto">
@@ -762,16 +764,10 @@ export const SearchSection = ({
              </div>
            )}
            <div
              className={`mobile:fixed mobile:left-1/2 mobile:transform mobile:-translate-x-1/2 mobile:max-w-search-bar-max mobile:w-[90%] mobile:z-100 mobile:bottom-12`}
              className={`mobile:max-w-search-bar-max mobile:w-[90%] mobile:z-100`}
            >
              <div
                className={`transition-all duration-500 ease-in-out overflow-hidden
                  ${
                    firstSearch
                      ? "opacity-100 max-h-[500px]"
                      : "opacity-0 max-h-0"
                  }`}
                onTransitionEnd={handleTransitionEnd}
                className={`transition-all duration-500 ease-in-out overflow-hidden opacity-100 max-h-[500px]`}
              >
                <div className="mt-48 mb-8 flex justify-center items-center">
                  <div className="w-message-xs 2xl:w-message-sm 3xl:w-message">
@@ -802,48 +798,50 @@ export const SearchSection = ({
                  setDefaultOverrides(SEARCH_DEFAULT_OVERRIDES_START);
                  await onSearch({ agentic, offset: 0 });
                }}
                finalAvailableDocumentSets={finalAvailableDocumentSets}
                finalAvailableSources={finalAvailableSources}
                finalAvailableDocumentSets={[]}
                finalAvailableSources={[]}
                filterManager={filterManager}
                documentSets={documentSets}
                ccPairs={ccPairs}
                tags={tags}
                documentSets={[]}
                ccPairs={[]}
                tags={[]}
              />
            </div>
            {!firstSearch && (
              <SearchAnswer
                isFetching={isFetching}
                dedupedQuotes={dedupedQuotes}
                searchResponse={searchResponse}
                setSearchAnswerExpanded={setSearchAnswerExpanded}
                searchAnswerExpanded={searchAnswerExpanded}
                setCurrentFeedback={setCurrentFeedback}
                searchState={searchState}
              />
            )}
            <div className="mt-6">
            {!firstSearch && (
              <SearchAnswer
                isFetching={isFetching}
                dedupedQuotes={dedupedQuotes}
                searchResponse={searchResponse}
                setSearchAnswerExpanded={setSearchAnswerExpanded}
                searchAnswerExpanded={searchAnswerExpanded}
                setCurrentFeedback={setCurrentFeedback}
                searchState={searchState}
              />
            )}

            {!settings?.isMobile && (
              <div className="mt-6">
                {!(agenticResults && isFetching) || disabledAgentic ? (
                  <SearchResultsDisplay
                    searchState={searchState}
                    disabledAgentic={disabledAgentic}
                    contentEnriched={contentEnriched}
                    comments={comments}
                    sweep={sweep}
                    agenticResults={
                      shouldUseAgenticDisplay && !disabledAgentic
                    }
                    performSweep={performSweep}
                    searchResponse={searchResponse}
                    isFetching={isFetching}
                    defaultOverrides={defaultOverrides}
                  />
                ) : (
                  <></>
                )}
              </div>
            )}
            {!settings?.isMobile && (
              <div className="mt-6">
                {!(agenticResults && isFetching) || disabledAgentic ? (
                  <SearchResultsDisplay
                    searchState={searchState}
                    disabledAgentic={disabledAgentic}
                    contentEnriched={contentEnriched}
                    comments={comments}
                    sweep={sweep}
                    agenticResults={
                      shouldUseAgenticDisplay && !disabledAgentic
                    }
                    performSweep={performSweep}
                    searchResponse={searchResponse}
                    isFetching={isFetching}
                    defaultOverrides={defaultOverrides}
                  />
                ) : (
                  <></>
                )}
              </div>
            )}
          </div>
        </div>
      </div>
}

@@ -67,10 +67,7 @@ export function UserProvider({
        isLoadingUser,
        refreshUser,
        isAdmin: upToDateUser?.role === UserRole.ADMIN,
        // Curator status applies for either global or basic curator
        isCurator:
          upToDateUser?.role === UserRole.CURATOR ||
          upToDateUser?.role === UserRole.GLOBAL_CURATOR,
        isCurator: upToDateUser?.role === UserRole.CURATOR,
        isCloudSuperuser: upToDateUser?.is_cloud_superuser ?? false,
      }}
    >

@@ -22,8 +22,7 @@ export async function fetchAssistantData(): Promise<AssistantData> {
  // Fetch core assistants data first
  const [assistants, assistantsFetchError] = await fetchAssistantsSS();
  if (assistantsFetchError) {
    // This is not a critical error and occurs when the user is not logged in
    console.warn(`Failed to fetch assistants - ${assistantsFetchError}`);
    console.error(`Failed to fetch assistants - ${assistantsFetchError}`);
    return defaultState;
  }


@@ -36,10 +36,8 @@ export const SIDEBAR_WIDTH = `w-[350px]`;
export const LOGOUT_DISABLED =
  process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";

// Default sidebar open is true if the environment variable is not set
export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
  process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true" ??
  true;
  process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true";

export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";

@@ -62,9 +60,6 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY
  ? true
  : false;

export const GTM_ENABLED =
  process.env.NEXT_PUBLIC_GTM_ENABLED?.toLowerCase() === "true";

export const DISABLE_LLM_DOC_RELEVANCE =
  process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";


@@ -174,6 +174,7 @@ export function useLlmOverride(
      modelName: "",
    }
  );

  const [llmOverride, setLlmOverride] = useState<LlmOverride>(
    currentChatSession && currentChatSession.current_alternate_model
      ? destructureValue(currentChatSession.current_alternate_model)

@@ -1,14 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Connectors - Add Connector",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/add-connector");
    await expect(page.locator("h1.text-3xl")).toHaveText("Add Connector");
    await expect(page.locator("h1.text-lg").nth(0)).toHaveText(/^Storage/);
  }
);
@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - User Management - API Keys",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/api-key");
    await expect(page.locator("h1.text-3xl")).toHaveText("API Keys");
    await expect(page.locator("p.text-sm")).toHaveText(
      /^API Keys allow you to access Danswer APIs programmatically/
    );
    await expect(
      page.getByRole("button", { name: "Create API Key" })
    ).toHaveCount(1);
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Custom Assistants - Assistants",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/assistants");
    await expect(page.locator("h1.text-3xl")).toHaveText("Assistants");
    await expect(page.locator("p.text-sm").nth(0)).toHaveText(
      /^Assistants are a way to build/
    );
  }
);
@@ -1,24 +0,0 @@
// dependency for all admin user tests

import { test as setup, expect } from "@playwright/test";
import { TEST_CREDENTIALS } from "./constants";

setup("authenticate", async ({ page }) => {
  const { email, password } = TEST_CREDENTIALS;

  await page.goto("http://localhost:3000/search");

  await page.waitForURL("http://localhost:3000/auth/login?next=%2Fsearch");

  await expect(page).toHaveTitle("Danswer");

  await page.fill("#email", email);
  await page.fill("#password", password);

  // Click the login button
  await page.click('button[type="submit"]');

  await page.waitForURL("http://localhost:3000/search");

  await page.context().storageState({ path: "admin_auth.json" });
});
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Custom Assistants - Slack Bots",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/bots");
    await expect(page.locator("h1.text-3xl")).toHaveText("Slack Bots");
    await expect(page.locator("p.text-sm").nth(0)).toHaveText(
      /^Setup Slack bots that connect to Danswer./
    );
  }
);
@@ -1,18 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Configuration - Document Processing",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto(
      "http://localhost:3000/admin/configuration/document-processing"
    );
    await expect(page.locator("h1.text-3xl")).toHaveText("Document Processing");
    await expect(page.locator("h3.text-2xl")).toHaveText(
      "Process with Unstructured API"
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Configuration - LLM",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/configuration/llm");
    await expect(page.locator("h1.text-3xl")).toHaveText("LLM Setup");
    await expect(page.locator("h1.text-lg").nth(0)).toHaveText(
      "Enabled LLM Providers"
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Configuration - Search Settings",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/configuration/search");
    await expect(page.locator("h1.text-3xl")).toHaveText("Search Settings");
    await expect(page.locator("h1.text-lg").nth(0)).toHaveText(
      "Embedding Model"
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Document Management - Feedback",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/documents/explorer");
    await expect(page.locator("h1.text-3xl")).toHaveText("Document Explorer");
    await expect(page.locator("div.flex.text-emphasis.mt-3")).toHaveText(
      "Search for a document above to modify its boost or hide it from searches."
    );
  }
);
@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Document Management - Feedback",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/documents/feedback");
    await expect(page.locator("h1.text-3xl")).toHaveText("Document Feedback");
    await expect(page.locator("h1.text-lg").nth(0)).toHaveText(
      "Most Liked Documents"
    );
    await expect(page.locator("h1.text-lg").nth(1)).toHaveText(
      "Most Disliked Documents"
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Document Management - Document Sets",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/documents/sets");
    await expect(page.locator("h1.text-3xl")).toHaveText("Document Sets");
    await expect(page.locator("p.text-sm")).toHaveText(
      /^Document Sets allow you to group logically connected documents into a single bundle./
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - User Management - Groups",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/groups");
    await expect(page.locator("h1.text-3xl")).toHaveText("Manage User Groups");
    await expect(
      page.getByRole("button", { name: "Create New User Group" })
    ).toHaveCount(1);
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Connectors - Existing Connectors",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/indexing/status");
    await expect(page.locator("h1.text-3xl")).toHaveText("Existing Connectors");
    await expect(page.locator("p.text-sm")).toHaveText(
      /^It looks like you don't have any connectors setup yet./
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Performance - Custom Analytics",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/performance/custom-analytics");
    await expect(page.locator("h1.text-3xl")).toHaveText("Custom Analytics");
    await expect(page.locator("div.font-medium").nth(0)).toHaveText(
      "Custom Analytics is not enabled."
    );
  }
);
@@ -1,22 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test.describe("Admin Performance Query History", () => {
  // Ignores the diff for elements targeted by the specified list of selectors
  // exclude button since they change based on the date
  test.use({ ignoreSelectors: ["button"] });

  test(
    "Admin - Performance - Query History",
    {
      tag: "@admin",
    },
    async ({ page }, testInfo) => {
      // Test simple loading
      await page.goto("http://localhost:3000/admin/performance/query-history");
      await expect(page.locator("h1.text-3xl")).toHaveText("Query History");
      await expect(page.locator("p.text-sm").nth(0)).toHaveText(
        "Feedback Type"
      );
    }
  );
});
@@ -1,20 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test.describe("Admin Performance Usage", () => {
  // Ignores the diff for elements targeted by the specified list of selectors
  // exclude button and svg since they change based on the date
  test.use({ ignoreSelectors: ["button", "svg"] });

  test(
    "Admin - Performance - Usage Statistics",
    {
      tag: "@admin",
    },
    async ({ page }, testInfo) => {
      await page.goto("http://localhost:3000/admin/performance/usage");
      await expect(page.locator("h1.text-3xl")).toHaveText("Usage Statistics");
      await expect(page.locator("h1.text-lg").nth(0)).toHaveText("Usage");
      await expect(page.locator("h1.text-lg").nth(1)).toHaveText("Feedback");
    }
  );
});
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Custom Assistants - Prompt Library",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/prompt-library");
    await expect(page.locator("h1.text-3xl")).toHaveText("Prompt Library");
    await expect(page.locator("p.text-sm")).toHaveText(
      /^Create prompts that can be accessed/
    );
  }
);
@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Settings - Workspace Settings",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/settings");
    await expect(page.locator("h1.text-3xl")).toHaveText("Workspace Settings");
    await expect(page.locator("p.text-sm").nth(0)).toHaveText(
      /^Manage general Danswer settings applicable to all users in the workspace./
    );
    await expect(
      page.getByRole("button", { name: "Set Retention Limit" })
    ).toHaveCount(1);
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Custom Assistants - Standard Answers",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/standard-answer");
    await expect(page.locator("h1.text-3xl")).toHaveText("Standard Answers");
    await expect(page.locator("p.text-sm").nth(0)).toHaveText(
      /^Manage the standard answers for pre-defined questions./
    );
  }
);
@@ -1,22 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - User Management - Token Rate Limits",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/token-rate-limits");
    await expect(page.locator("h1.text-3xl")).toHaveText("Token Rate Limits");
    await expect(page.locator("p.text-sm").nth(0)).toHaveText(
      /^Token rate limits enable you control how many tokens can be spent in a given time period./
    );
    await expect(
      page.getByRole("button", { name: "Create a Token Rate Limit" })
    ).toHaveCount(1);
    await expect(page.locator("h1.text-lg")).toHaveText(
      "Global Token Rate Limits"
    );
  }
);
@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Custom Assistants - Tools",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/tools");
    await expect(page.locator("h1.text-3xl")).toHaveText("Tools");
    await expect(page.locator("p.text-sm")).toHaveText(
      "Tools allow assistants to retrieve information or take actions."
    );
  }
);
@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - User Management - Groups",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/users");
    await expect(page.locator("h1.text-3xl")).toHaveText("Manage Users");
    await expect(page.locator("div.font-bold").nth(0)).toHaveText(
      "Invited Users"
    );
    await expect(page.locator("div.font-bold").nth(1)).toHaveText(
      "Current Users"
    );
  }
);
@@ -1,18 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Admin - Performance - Whitelabeling",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/admin/whitelabeling");
    await expect(page.locator("h1.text-3xl")).toHaveText("Whitelabeling");
    await expect(page.locator("div.block").nth(0)).toHaveText(
      "Application Name"
    );
    await expect(page.locator("div.block").nth(1)).toHaveText("Custom Logo");
    await expect(page.getByRole("button", { name: "Update" })).toHaveCount(1);
  }
);
@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Chat",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/chat");
    await expect(page.locator("div.text-2xl").nth(0)).toHaveText("General");
    await expect(page.getByRole("button", { name: "Search S" })).toHaveClass(
      /text-text-application-untoggled/
    );
    await expect(page.getByRole("button", { name: "Chat D" })).toHaveClass(
      /text-text-application-toggled/
    );
  }
);
@@ -1,5 +0,0 @@
// constants.js
export const TEST_CREDENTIALS = {
  email: "admin_user@test.com",
  password: "test",
};
@@ -1,31 +1,27 @@
// ➕ Add this line
import { test, expect, takeSnapshot } from "@chromatic-com/playwright";
import { TEST_CREDENTIALS } from "./constants";

// Then use as normal 👇
test(
  "Homepage",
  {
    tag: "@guest",
  },
  async ({ page }, testInfo) => {
    // Test redirect to login, and redirect to search after login
    const { email, password } = TEST_CREDENTIALS;
test("Homepage", async ({ page }, testInfo) => {
  // Test redirect to login, and redirect to search after login

    await page.goto("http://localhost:3000/search");
  // move these into a constants file or test fixture soon
  let email = "admin_user@test.com";
  let password = "test";

    await page.waitForURL("http://localhost:3000/auth/login?next=%2Fsearch");
  await page.goto("http://localhost:3000/search");

    await expect(page).toHaveTitle("Danswer");
  await page.waitForURL("http://localhost:3000/auth/login?next=%2Fsearch");

    await takeSnapshot(page, "Before login", testInfo);
  await expect(page).toHaveTitle("Danswer");

    await page.fill("#email", email);
    await page.fill("#password", password);
  await takeSnapshot(page, "Before login", testInfo);

    // Click the login button
    await page.click('button[type="submit"]');
  await page.fill("#email", email);
  await page.fill("#password", password);

    await page.waitForURL("http://localhost:3000/search");
  }
);
  // Click the login button
  await page.click('button[type="submit"]');

  await page.waitForURL("http://localhost:3000/search");
});

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";

test(
  "Search",
  {
    tag: "@admin",
  },
  async ({ page }, testInfo) => {
    // Test simple loading
    await page.goto("http://localhost:3000/search");
    await expect(page.locator("div.text-3xl")).toHaveText("Unlock Knowledge");
    await expect(page.getByRole("button", { name: "Search S" })).toHaveClass(
      /text-text-application-toggled/
    );
    await expect(page.getByRole("button", { name: "Chat D" })).toHaveClass(
      /text-text-application-untoggled/
    );
  }
);