Compare commits

4 Commits

Author SHA1 Message Date
pablodanswer 233713cde3 hide animations 2024-11-20 14:47:51 -08:00
Yuhong Sun c0b17b4c51 k 2024-11-20 10:54:09 -08:00
Yuhong Sun 15f30b0050 k 2024-11-20 10:47:15 -08:00
pablodanswer 39d9df9b1b k 2024-11-20 09:47:22 -08:00
98 changed files with 927 additions and 1816 deletions


@@ -65,7 +65,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}


@@ -13,10 +13,7 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -198,9 +195,6 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \


@@ -3,7 +3,12 @@ concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -11,8 +16,6 @@ env:
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
@@ -105,7 +108,7 @@ jobs:
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -190,8 +193,7 @@ jobs:
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
name: Run Chromatic
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:


@@ -1,59 +0,0 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

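For reference, a minimal sketch (not part of the diff) of the merge rule the deleted migration applied per provider row: union the existing display_model_names with model_names, treating None as empty. The helper name is invented here, and sorted() is added only to make the result deterministic; the migration itself used list(set(...)).

def merge_display_names(
    model_names: list[str],
    display_model_names: list[str] | None,
) -> list[str]:
    # None falls back to an empty list, matching the migration's behavior
    return sorted(set((display_model_names or []) + model_names))

assert merge_display_names(["a", "b"], ["b", "c"]) == ["a", "b", "c"]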

@@ -1,6 +1,5 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -15,9 +14,7 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
@@ -98,15 +95,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it


@@ -4,6 +4,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -16,7 +17,6 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasn't been implemented for the given connector, just pull
@@ -111,15 +111,10 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

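A hedged, dependency-free sketch of the cooperative-stop pattern extract_ids_from_runnable_connector now follows: poll should_stop() before each batch and report progress after it. StubCallback and the batch source are illustrative stand-ins, not code from the diff.

from collections.abc import Iterator

class StubCallback:
    def __init__(self, stop_after: int) -> None:
        self.seen = 0
        self.stop_after = stop_after

    def should_stop(self) -> bool:
        return self.seen >= self.stop_after

    def progress(self, amount: int) -> None:
        self.seen += amount

def consume_batches(batches: Iterator[list[str]], callback: StubCallback) -> set[str]:
    ids: set[str] = set()
    for batch in batches:
        if callback.should_stop():
            raise RuntimeError("Stop signal received")
        callback.progress(len(batch))
        ids.update(batch)
    return ids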

@@ -19,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -118,7 +118,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletePayload(
fence_payload = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)


@@ -29,7 +29,7 @@ from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
@@ -66,9 +66,9 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
if not source_sync_period:
return True


@@ -3,7 +3,6 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
@@ -17,6 +16,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -33,8 +33,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -44,11 +42,9 @@ from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@@ -61,7 +57,7 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
@@ -77,7 +73,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = ""
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
@@ -85,17 +80,15 @@ class IndexingCallback(IndexingHeartbeatInterface):
return True
return False
def progress(self, tag: str, amount: int) -> None:
def progress(self, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
@@ -104,54 +97,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -162,7 +107,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -172,7 +117,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -191,18 +135,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -259,29 +198,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -291,11 +207,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -400,11 +311,10 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -455,8 +365,6 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -477,16 +385,13 @@ def try_creating_indexing_task(
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(None)
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -504,7 +409,7 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -512,7 +417,7 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()
job = client.submit(
connector_indexing_task_wrapper,
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -523,7 +428,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -531,7 +436,7 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -555,7 +460,7 @@ def connector_indexing_proxy_task(
if job.status == "error":
task_logger.error(
f"Indexing watchdog - spawned task exceptioned: "
f"Indexing proxy - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -567,7 +472,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -575,38 +480,6 @@ def connector_indexing_proxy_task(
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -661,7 +534,6 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -669,18 +541,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
if not redis_connector_index.fenced: # The fence must exist
# wait for the fence to come up
if not redis_connector_index.fenced:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload # The payload must exist
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -703,7 +575,7 @@ def connector_indexing_task(
)
break
lock: RedisLock = r.lock(
lock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -712,7 +584,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -747,7 +619,7 @@ def connector_indexing_task(
)
# define a callback class
callback = IndexingCallback(
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,

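A minimal sketch, assuming a local Redis, of the lock discipline the tasks above use to serialize indexing triggers: non-blocking acquire, bail out if another caller holds the lock, and release only a lock this process still owns. The key name and connection details are placeholders.

import redis

r = redis.Redis(host="localhost", port=6379)
lock = r.lock("lock:try_creating_indexing_task", timeout=30)

if lock.acquire(blocking=False):
    try:
        pass  # create the index attempt, set the fence, send the celery task
    finally:
        if lock.owned():
            lock.release()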

@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -277,7 +277,7 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = IndexingCallback(
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,


@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -46,10 +47,13 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -645,26 +649,20 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
result_state = result.state
if (
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if status_int is None: # completion signal not set ... check for errors
# If we get here, and then the task both sets the completion signal and finishes,
# we will incorrectly abort the task. We must check result state, then check
# get_completion again to avoid the race condition.
if result_state in READY_STATES:
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
@@ -699,6 +697,37 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if r.exists(fence_key):
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -750,6 +779,25 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"permissions_sync={n_permissions_sync} "
)
# Fail any index attempts in the DB that don't have fences
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

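A hedged sketch of the double-check pattern monitor_ccpair_indexing_taskset performs to avoid the race described in the comments above. The two callables stand in for redis_connector_index.get_completion() and the celery AsyncResult state; the function name is invented.

from collections.abc import Callable

def worker_crashed(
    get_completion: Callable[[], int | None],
    get_result_state: Callable[[], str],
    ready_states: frozenset[str] = frozenset({"SUCCESS", "FAILURE", "REVOKED"}),
) -> bool:
    if get_completion() is not None:
        return False  # completion signal set: the task finished normally
    if get_result_state() not in ready_states:
        return False  # still running: the signal may simply not be set yet
    # Result state is terminal but the signal was unset; re-check once more.
    # The signal cannot appear after the state goes terminal, so a second
    # miss confirms the worker died before signalling completion.
    return get_completion() is None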

@@ -1,5 +1,7 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -29,7 +31,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
@@ -40,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -91,7 +106,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -123,7 +138,13 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
callback=callback,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
@@ -136,7 +157,6 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -208,9 +228,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
raise RuntimeError("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -271,7 +289,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -401,7 +419,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:

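A minimal sketch of the freq-gated heartbeat wired into the embedder above: liveness is reported only every `freq` progress calls so the per-batch cost stays negligible. The print stands in for the DB update IndexingHeartbeat performs, and the class name here is invented.

class CountingHeartbeat:
    def __init__(self, freq: int) -> None:
        self.freq = freq
        self.calls = 0

    def heartbeat(self) -> None:
        self.calls += 1
        if self.calls % self.freq == 0:
            print(f"still making progress after {self.calls} batches")

hb = CountingHeartbeat(freq=10)
for _ in range(30):
    hb.heartbeat()  # reports at batches 10, 20, and 30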

@@ -7,9 +7,9 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -70,7 +70,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self._confluence_client: OnyxConfluence | None = None
self.confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -97,44 +97,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials=credentials,
self.confluence_client = build_confluence_client(
credentials_json=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comment in self.confluence_client.paginated_cql_retrieval(
for comments in self.confluence_client.paginated_cql_page_retrieval(
cql=comment_cql,
expand=expand,
):
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
return comment_string
@@ -146,6 +141,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If it's a page, it extracts the text and adds the comments to the document text.
If it's an attachment, it just downloads the attachment and converts it into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
@@ -155,19 +153,16 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
self.confluence_client, confluence_object
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client, attachment=confluence_object
self.confluence_client, confluence_object
)
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
@@ -198,39 +193,44 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
for attachments in self.confluence_client.paginated_cql_page_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -255,47 +255,52 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
for pages in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []

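A hedged sketch of the re-batching the connector now does: the client yields whole server pages, which get re-chunked into batch_size groups with a final flush for the remainder. rebatch() is an illustrative name, not a function from the diff.

from collections.abc import Iterator

def rebatch(pages: Iterator[list[str]], batch_size: int) -> Iterator[list[str]]:
    batch: list[str] = []
    for page in pages:
        for item in page:
            batch.append(item)
            if len(batch) >= batch_size:
                yield batch
                batch = []
    if batch:
        yield batch  # flush the remainder

assert list(rebatch(iter([["a", "b", "c"], ["d"]]), 2)) == [["a", "b"], ["c", "d"]]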

@@ -20,10 +20,6 @@ F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
pass
@@ -84,7 +80,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
TIMEOUT = 3600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
@@ -99,10 +95,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
@@ -149,7 +141,7 @@ class OnyxConfluence(Confluence):
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[dict[str, Any]]:
) -> Iterator[list[dict[str, Any]]]:
"""
This will paginate through the top level query.
"""
@@ -161,43 +153,46 @@ class OnyxConfluence(Confluence):
while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
# yield the results individually
yield from next_response.get("results", [])
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_cql_retrieval(
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
yield from self._paginate_url(
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
@@ -206,7 +201,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
) -> Iterator[list[dict[str, Any]]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -226,110 +221,6 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object
def paginated_cql_user_retrieval(
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a separate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_without_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
)
spaces = confluence_client_without_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results

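A minimal sketch of the new _paginate_url contract: yield each REST page as a whole list and follow the _links.next cursor until it runs out. fetch is a stand-in for the client's self.get().

from collections.abc import Callable, Iterator
from typing import Any

def paginate(
    fetch: Callable[[str], dict[str, Any]], url: str
) -> Iterator[list[dict[str, Any]]]:
    while url:
        response = fetch(url)
        yield response.get("results", [])  # one whole page per yield
        url = response.get("_links", {}).get("next")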

@@ -2,7 +2,6 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
import bs4
@@ -72,9 +71,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -82,7 +79,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
@@ -104,72 +101,38 @@ def extract_text_from_confluence_html(
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
for html_page_reference in soup.findAll("ri:page"):
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_title = html_page_reference.attrs["ri:content-title"]
if not page_title:
continue
page_query = f"type=page and title='{page_title}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
for page_batch in confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
page_contents = page_batch[0]
break
except Exception as e:
except Exception:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
f"Error getting page contents for object {confluence_object}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
confluence_client, page_contents
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
@@ -269,3 +232,20 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)


@@ -1,8 +1,8 @@
import os
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import Issue
@@ -12,93 +12,129 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
from danswer.connectors.danswer_jira.utils import build_jira_client
from danswer.connectors.danswer_jira.utils import build_jira_url
from danswer.connectors.danswer_jira.utils import extract_jira_project
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
from danswer.connectors.danswer_jira.utils import get_comment_strs
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def _paginate_jql_search(
jira_client: JIRA,
jql: str,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
if len(issues) < max_results:
break
return jira_base, jira_project
start += max_results
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in jira.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
def fetch_jira_issues_batch(
jira_client: JIRA,
jql: str,
batch_size: int,
start_index: int,
jira_client: JIRA,
batch_size: int = INDEX_BATCH_SIZE,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
) -> tuple[list[Document], int]:
doc_batch = []
batch = jira_client.search_issues(
jql,
startAt=start_index,
maxResults=batch_size,
)
for jira in batch:
if type(jira) != Issue:
logger.warning(f"Found Jira object not of type Issue {jira}")
continue
if labels_to_skip and any(
label in jira.fields.labels for label in labels_to_skip
):
logger.info(
f"Skipping {jira.key} because it has a label to skip. Found "
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
issue.fields.description
jira.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
@@ -106,53 +142,66 @@ def fetch_jira_issues_batch(
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
people.add(
BasicExpertInfo(
display_name=jira.fields.creator.displayName,
email=jira.fields.creator.emailAddress,
)
)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
people.add(
BasicExpertInfo(
display_name=jira.fields.assignee.displayName, # type: ignore
email=jira.fields.assignee.emailAddress, # type: ignore
)
)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
yield Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=issue.fields.summary,
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
)
)
return doc_batch, len(batch)
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
class JiraConnector(LoadConnector, PollConnector):
def __init__(
self,
jira_project_url: str,
@@ -164,8 +213,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
self.jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
@@ -174,45 +223,54 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist)
@property
def jira_client(self) -> JIRA:
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
return self._jira_client
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
)
api_token = credentials["jira_api_token"]
# if the user provides an email, we assume it's a cloud deployment
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
self.jira_client = JIRA(
basic_auth=(email, api_token),
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
self.jira_client = JIRA(
token_auth=api_token,
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
jql = f"project = {self.quoted_jira_project}"
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
)
yield document_batch
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
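# a page shorter than batch_size means we've consumed the final page, so stop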
if fetched_batch_size < self.batch_size:
break
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@@ -220,54 +278,31 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {self.quoted_jira_project} AND "
f"project = {quoted_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
jira_client=self.jira_client,
jql=jql,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
issue_key = best_effort_get_field_from_issue(issue, "key")
id = build_jira_url(self.jira_client, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
perm_sync_data=None,
)
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=jql,
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
)
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
slim_doc_batch = []
yield slim_doc_batch
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
if __name__ == "__main__":

View File

@@ -1,136 +1,17 @@
"""Module with custom fields processing functions"""
import os
from typing import Any
from typing import List
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
display_name = None
email = None
if hasattr(obj, "display_name"):
display_name = obj.display_name
else:
display_name = obj.get("displayName")
if hasattr(obj, "emailAddress"):
email = obj.emailAddress
else:
email = obj.get("emailAddress")
if not email and not display_name:
return None
return BasicExpertInfo(display_name=display_name, email=email)
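A quick sketch of the fallback behavior with a raw API payload (hypothetical values):

info = best_effort_basic_expert_info(
    {"displayName": "Jane Doe", "emailAddress": "jane@example.com"}
)
assert info is not None and info.email == "jane@example.com"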
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
api_token = credentials["jira_api_token"]
# if the user provides an email, we assume it's a cloud deployment
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
return JIRA(
basic_auth=(email, api_token),
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
return JIRA(
token_auth=api_token,
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
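A hedged usage sketch (credential values are made up; note that constructing the JIRA client may issue a server-info request):

client = build_jira_client(
    credentials={"jira_user_email": "me@example.com", "jira_api_token": "token"},
    jira_base="https://example.atlassian.net",
)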
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
return jira_base, jira_project
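Expected parsing behavior, sketched with a made-up URL:

base, project = extract_jira_project(
    "https://example.atlassian.net/jira/projects/DAN/boards/1"
)
assert base == "https://example.atlassian.net"
assert project == "DAN"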
def get_comment_strs(
issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in issue.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
class CustomFieldExtractor:
@staticmethod
def _process_custom_field_value(value: Any) -> str:

View File

@@ -2,7 +2,6 @@ import io
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -49,67 +48,6 @@ def _extract_sections_basic(
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
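# e.g. column_count=28: divmod(27, 26) -> (1, 1) yields "B", then divmod(0, 26) -> (0, 0) prepends "A" -> "AB"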
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -127,7 +65,6 @@ def _extract_sections_basic(
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,

View File

@@ -197,9 +197,7 @@ class SlackbotHandler:
return
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
tokens_changed = (
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
)
tokens_changed = slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(

View File

@@ -67,13 +67,6 @@ def create_index_attempt(
return new_attempt.id
def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt:
db_session.delete(index_attempt)
db_session.commit()
def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,

View File

@@ -1181,7 +1181,7 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Models to actually display to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True

View File

@@ -259,6 +259,7 @@ def get_personas(
) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:

View File

@@ -10,7 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -125,7 +125,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -134,7 +134,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.callback = callback
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -356,14 +356,9 @@ class Chunker:
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("Chunker.chunk: Stop signal detected")
final_chunks.extend(self._handle_single_document(document))
chunks = self._handle_single_document(document)
final_chunks.extend(chunks)
if self.callback:
self.callback.progress("Chunker.chunk", len(chunks))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks

View File

@@ -2,7 +2,7 @@ from abc import ABC
from abc import abstractmethod
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -34,7 +34,7 @@ class IndexingEmbedder(ABC):
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
callback: IndexingHeartbeatInterface | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -60,7 +60,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
callback=callback,
heartbeat=heartbeat,
)
@abstractmethod
@@ -83,7 +83,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -95,7 +95,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url,
api_version,
deployment_name,
callback,
heartbeat,
)
@log_function_time()
@@ -201,9 +201,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls,
search_settings: SearchSettings,
callback: IndexingHeartbeatInterface | None = None,
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -215,5 +213,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
callback=callback,
heartbeat=heartbeat,
)

View File

@@ -1,15 +1,41 @@
from abc import ABC
from abc import abstractmethod
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to
run_indexing_entrypoint."""
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
@abstractmethod
def progress(self, tag: str, amount: int) -> None:
"""Send progress updates to the caller."""
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
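A minimal usage sketch (assumes a SQLAlchemy session and an existing index attempt row; the loop names are illustrative):

heartbeat = IndexingHeartbeat(index_attempt_id=1, db_session=db_session, freq=100)
for doc in docs_to_index:  # hypothetical work loop
    index_document(doc)  # hypothetical work function
    heartbeat.heartbeat()  # bumps IndexAttempt.time_updated on every 100th call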

View File

@@ -34,7 +34,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -414,7 +414,6 @@ def build_indexing_pipeline(
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@@ -441,8 +440,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(

View File

@@ -231,16 +231,16 @@ class QuotesProcessor:
model_previous = self.model_output
self.model_output += token
if not self.found_answer_start:
m = answer_pattern.search(self.model_output)
if m:
self.found_answer_start = True
# Prevent heavy cases of hallucinations
if self.is_json_prompt and len(self.model_output) > 400:
self.found_answer_end = True
if self.is_json_prompt and len(self.model_output) > 70:
logger.warning("LLM did not produce json as prompted")
logger.debug(f"Model output thus far: {self.model_output}")
self.found_answer_end = True
return
remaining = self.model_output[m.end() :]

View File

@@ -16,7 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -99,7 +99,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
@@ -116,7 +116,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.callback = callback
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -160,10 +160,6 @@ class EmbeddingModel:
embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
@@ -183,8 +179,8 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
def encode(

View File

@@ -17,7 +17,7 @@ from danswer.db.document import construct_document_select_for_connector_credenti
from danswer.db.models import Document as DbDocument
class RedisConnectorDeletePayload(BaseModel):
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime
@@ -54,18 +54,20 @@ class RedisConnectorDelete:
return False
@property
def payload(self) -> RedisConnectorDeletePayload | None:
def payload(self) -> RedisConnectorDeletionFenceData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorDeletePayload.model_validate_json(cast(str, fence_str))
payload = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(self, payload: RedisConnectorDeletePayload | None) -> None:
def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
if not payload:
self.redis.delete(self.fence_key)
return

View File

@@ -30,6 +30,7 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/llm")
basic_router = APIRouter(prefix="/llm")

View File

@@ -4,10 +4,6 @@ import re
import string
from urllib.parse import quote
from danswer.utils.logger import setup_logger
logger = setup_logger(__name__)
ESCAPE_SEQUENCE_RE = re.compile(
r"""
@@ -81,8 +77,7 @@ def extract_embedded_json(s: str) -> dict:
last_brace_index = s.rfind("}")
if first_brace_index == -1 or last_brace_index == -1:
logger.warning("No valid json found, assuming answer is entire string")
return {"answer": s, "quotes": []}
raise ValueError("No valid json found")
json_str = s[first_brace_index : last_brace_index + 1]
try:

View File

@@ -411,8 +411,6 @@ def _validate_curator_status__no_commit(
.all()
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -438,15 +436,6 @@ def update_user_curator_relationship(
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
if user.role == UserRole.ADMIN:
raise ValueError(
f"User '{user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=set_curator_request.user_id,

View File

@@ -1,5 +1,7 @@
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from typing import Any
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
@@ -9,30 +11,26 @@ from ee.danswer.db.external_perm import ExternalUserGroup
logger = setup_logger()
def _build_group_member_email_map(
def _get_group_members_email_paginated(
confluence_client: OnyxConfluence,
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
user = user_result["user"]
email = user.get("email")
group_name: str,
) -> set[str]:
members: list[dict[str, Any]] = []
for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
members.extend(member_batch)
group_member_emails: set[str] = set()
for member in members:
email = member.get("email")
if not email:
# This field is only present in Confluence Server
user_name = user.get("username")
# If it is present, try to get the email using a Server-specific method
user_name = member.get("username")
if user_name:
email = get_user_email_from_username__server(
confluence_client=confluence_client,
user_name=user_name,
)
if not email:
# If we still don't have an email, skip this user
continue
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
if email:
group_member_emails.add(email)
return group_member_emails
@@ -40,20 +38,31 @@ def _build_group_member_email_map(
def confluence_group_sync(
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
credentials_json=cc_pair.credential.credential_json,
is_cloud=is_cloud,
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,
)
# Get all group names
group_names: list[str] = []
for group_batch in confluence_client.paginated_groups_retrieval():
for group in group_batch:
if group_name := group.get("name"):
group_names.append(group_name)
# For each group name, get all members and create a danswer group
danswer_groups: list[ExternalUserGroup] = []
for group_id, group_member_emails in group_member_email_map.items():
for group_name in group_names:
group_member_emails = _get_group_members_email_paginated(
confluence_client, group_name
)
if not group_member_emails:
continue
danswer_groups.append(
ExternalUserGroup(
id=group_id,
id=group_name,
user_emails=list(group_member_emails),
)
)

View File

@@ -55,12 +55,7 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
DocumentSource.SLACK: 5 * 60,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 60 seconds
DocumentSource.GOOGLE_DRIVE: 60,
DocumentSource.CONFLUENCE: 60,
}
EXTERNAL_GROUP_SYNC_PERIOD: int = 30 # 30 seconds
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:

View File

@@ -1,84 +0,0 @@
import os
from datetime import datetime
from datetime import timezone
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_connector_creation(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connectors
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
cc_pair_info = CCPairManager.get_single(
cc_pair_1.id, user_performing_action=admin_user
)
assert cc_pair_info
assert cc_pair_info.creator
assert str(cc_pair_info.creator) == admin_user.id
assert cc_pair_info.creator_email == admin_user.email
def test_overlapping_connector_creation(reset: None) -> None:
"""Tests that connectors indexing the same documents don't interfere with each other.
A previous bug caused document-by-cc-pair entries not to be added for new connectors
when the docs already existed via another connector and were up to date relative to the source.
"""
admin_user: DATestUser = UserManager.create(name="admin_user")
config = {
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
"space": "DailyConne",
"is_cloud": True,
"page_id": "",
}
credential = {
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
}
# store the time before we create the connector so that we know after
# when the indexing should have started
now = datetime.now(timezone.utc)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.CONFLUENCE,
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
)
now = datetime.now(timezone.utc)
cc_pair_2 = CCPairManager.create_from_scratch(
source=DocumentSource.CONFLUENCE,
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
cc_pair_2, now, timeout=120, user_performing_action=admin_user
)
info_1 = CCPairManager.get_single(cc_pair_1.id, user_performing_action=admin_user)
assert info_1
info_2 = CCPairManager.get_single(cc_pair_2.id, user_performing_action=admin_user)
assert info_2
assert info_1.num_docs_indexed == info_2.num_docs_indexed

View File

@@ -29,6 +29,78 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_creation(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connectors
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
cc_pair_info = CCPairManager.get_single(
cc_pair_1.id, user_performing_action=admin_user
)
assert cc_pair_info
assert cc_pair_info.creator
assert str(cc_pair_info.creator) == admin_user.id
assert cc_pair_info.creator_email == admin_user.email
# TODO(rkuo): will enable this once i have credentials on github
# def test_overlapping_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")
# config = {
# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
# "space": os.environ["CONFLUENCE_TEST_SPACE"],
# "is_cloud": True,
# "page_id": "",
# }
# credential = {
# "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
# }
# # store the time before we create the connector so that we know after
# # when the indexing should have started
# now = datetime.now(timezone.utc)
# # create connector
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )
# CCPairManager.wait_for_indexing(
# cc_pair_1, now, timeout=60, user_performing_action=admin_user
# )
# cc_pair_2 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )
# CCPairManager.wait_for_indexing(
# cc_pair_2, now, timeout=60, user_performing_action=admin_user
# )
# info_1 = CCPairManager.get_single(cc_pair_1.id)
# assert info_1
# info_2 = CCPairManager.get_single(cc_pair_2.id)
# assert info_2
# assert info_1.num_docs_indexed == info_2.num_docs_indexed
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
@@ -17,48 +18,49 @@ def mock_jira_client() -> MagicMock:
@pytest.fixture
def mock_issue_small() -> MagicMock:
issue = MagicMock(spec=Issue)
fields = MagicMock()
fields.description = "Small description"
fields.comment = MagicMock()
fields.comment.comments = [
issue = MagicMock()
issue.key = "SMALL-1"
issue.fields.description = "Small description"
issue.fields.comment.comments = [
MagicMock(body="Small comment 1"),
MagicMock(body="Small comment 2"),
]
fields.creator = MagicMock()
fields.creator.displayName = "John Doe"
fields.creator.emailAddress = "john@example.com"
fields.summary = "Small Issue"
fields.updated = "2023-01-01T00:00:00+0000"
fields.labels = []
issue.fields = fields
issue.key = "SMALL-1"
issue.fields.creator.displayName = "John Doe"
issue.fields.creator.emailAddress = "john@example.com"
issue.fields.summary = "Small Issue"
issue.fields.updated = "2023-01-01T00:00:00+0000"
issue.fields.labels = []
return issue
@pytest.fixture
def mock_issue_large() -> MagicMock:
issue = MagicMock(spec=Issue)
fields = MagicMock()
fields.description = "a" * 99_000
fields.comment = MagicMock()
fields.comment.comments = [
# This will be larger than 100KB
issue = MagicMock()
issue.key = "LARGE-1"
issue.fields.description = "a" * 99_000
issue.fields.comment.comments = [
MagicMock(body="Large comment " * 1000),
MagicMock(body="Another large comment " * 1000),
]
fields.creator = MagicMock()
fields.creator.displayName = "Jane Doe"
fields.creator.emailAddress = "jane@example.com"
fields.summary = "Large Issue"
fields.updated = "2023-01-02T00:00:00+0000"
fields.labels = []
issue.fields = fields
issue.key = "LARGE-1"
issue.fields.creator.displayName = "Jane Doe"
issue.fields.creator.emailAddress = "jane@example.com"
issue.fields.summary = "Large Issue"
issue.fields.updated = "2023-01-02T00:00:00+0000"
issue.fields.labels = []
return issue
@pytest.fixture
def patched_type() -> Callable[[Any], type]:
def _patched_type(obj: Any) -> type:
if isinstance(obj, MagicMock):
return Issue
return type(obj)
return _patched_type
@pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
@@ -67,9 +69,11 @@ def mock_jira_api_version() -> Generator[Any, Any, Any]:
@pytest.fixture
def patched_environment(
patched_type: type,
mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]:
yield
with patch("danswer.connectors.danswer_jira.connector.type", patched_type):
yield
def test_fetch_jira_issues_batch_small_ticket(
@@ -79,8 +83,9 @@ def test_fetch_jira_issues_batch_small_ticket(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 1
assert len(docs) == 1
assert docs[0].id.endswith("/SMALL-1")
assert "Small description" in docs[0].sections[0].text
@@ -95,8 +100,9 @@ def test_fetch_jira_issues_batch_large_ticket(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 1
assert len(docs) == 0 # The large ticket should be skipped
@@ -108,8 +114,9 @@ def test_fetch_jira_issues_batch_mixed_tickets(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 2
assert len(docs) == 1 # Only the small ticket should be included
assert docs[0].id.endswith("/SMALL-1")
@@ -123,6 +130,7 @@ def test_fetch_jira_issues_batch_custom_size_limit(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 2
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit

View File

@@ -1,16 +1,15 @@
from typing import Any
import pytest
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
class MockHeartbeat(IndexingHeartbeatInterface):
class MockHeartbeat(Heartbeat):
def __init__(self) -> None:
self.call_count = 0
def should_stop(self) -> bool:
return False
def progress(self, tag: str, amount: int) -> None:
def heartbeat(self, metadata: Any = None) -> None:
self.call_count += 1

View File

@@ -74,7 +74,7 @@ def test_chunker_heartbeat(
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
callback=mock_heartbeat,
heartbeat=mock_heartbeat,
)
chunks = chunker.chunk([document])

View File

@@ -0,0 +1,80 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
@pytest.fixture
def mock_db_session() -> MagicMock:
return MagicMock(spec=Session)
@pytest.fixture
def mock_index_attempt() -> MagicMock:
return MagicMock(spec=IndexAttempt)
def test_indexing_heartbeat(
mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt:
mock_get_index_attempt.return_value = mock_index_attempt
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=5
)
# Test that heartbeat doesn't update before freq is reached
for _ in range(4):
heartbeat.heartbeat()
mock_db_session.commit.assert_not_called()
# Test that heartbeat updates when freq is reached
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
assert mock_index_attempt.time_updated is not None
mock_db_session.commit.assert_called_once()
# Reset mock calls
mock_db_session.reset_mock()
mock_get_index_attempt.reset_mock()
# Test that heartbeat updates again after freq more calls
for _ in range(5):
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt, patch(
"danswer.indexing.indexing_heartbeat.logger"
) as mock_logger:
mock_get_index_attempt.return_value = None
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=1
)
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
mock_logger.error.assert_called_once_with(
"Index attempt not found, this should not happen!"
)
mock_db_session.commit.assert_not_called()

View File

@@ -324,13 +324,8 @@ def test_lengthy_prefixed_json_with_quotes() -> None:
assert quotes[0] == "Document"
def test_json_with_lengthy_prefix_and_quotes() -> None:
def test_prefixed_json_with_quotes() -> None:
tokens = [
"*** Based on the provided documents, there does not appear to be any information ",
"directly relevant to answering which documents are my favorite. ",
"The documents seem to be focused on describing the Danswer product ",
"and its features/use cases. Since I do not have personal preferences ",
"for documents, I will provide a general response:\n\n",
"```",
"json",
"\n",

View File

@@ -5,7 +5,7 @@
For general information, please read the instructions in this [README](https://github.com/danswer-ai/danswer/blob/main/deployment/README.md).
## Deploy in a system without GPU support
This part is elaborated precisely in this [README](https://github.com/danswer-ai/danswer/blob/main/deployment/README.md) in section *Docker Compose*. If you have any questions, please feel free to open an issue or get in touch in slack for support.
## Deploy in a system with GPU support
Running Model servers with GPU support while indexing and querying can result in significant improvements in performance. This is highly recommended if you have access to resources. Currently, Danswer offloads embedding model and tokenizers to the GPU VRAM and the size needed depends on chosen embedding model. For example, the embedding model `nomic-ai/nomic-embed-text-v1` takes up about 1GB of VRAM. That means running this model for inference and embedding pipeline would require roughly 2GB of VRAM.
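If you want to sanity-check available VRAM before picking an embedding model, one option (assuming PyTorch with CUDA support is installed) is:

import torch
props = torch.cuda.get_device_properties(0)
print(f"{props.total_memory / 1024**3:.1f} GiB total VRAM")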

4
web/.gitignore vendored
View File

@@ -34,7 +34,3 @@ yarn-error.log*
# typescript
*.tsbuildinfo
next-env.d.ts
/admin_auth.json
/build-archive.log

View File

@@ -69,9 +69,6 @@ ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}
RUN npx next build
# Step 2. Production image, copy all the files and run next
@@ -137,12 +134,9 @@ ARG NEXT_PUBLIC_POSTHOG_KEY
ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}
# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to

View File

@@ -1,8 +1,8 @@
import { defineConfig, devices } from "@playwright/test";
import { defineConfig } from "@playwright/test";
export default defineConfig({
// Other Playwright config options
testDir: "./tests/e2e", // Folder for test files
reporter: "list",
// Configure paths for screenshots
// expect: {
// toMatchSnapshot: {
@@ -11,30 +11,4 @@ export default defineConfig({
// },
// reporter: [["html", { outputFolder: "test-results/output/report" }]], // HTML report location
// outputDir: "test-results/output/screenshots", // Set output folder for test artifacts
projects: [
{
// dependency for admin workflows
name: "admin_setup",
testMatch: /.*\admin_auth.setup\.ts/,
},
{
// tests admin workflows
name: "chromium-admin",
grep: /@admin/,
use: {
...devices["Desktop Chrome"],
// Use prepared auth state.
storageState: "admin_auth.json",
},
dependencies: ["admin_setup"],
},
{
// tests logged out / guest workflows
name: "chromium-guest",
grep: /@guest/,
use: {
...devices["Desktop Chrome"],
},
},
],
});

View File

@@ -29,7 +29,9 @@ import { deleteApiKey, regenerateApiKey } from "./lib";
import { DanswerApiKeyForm } from "./DanswerApiKeyForm";
import { APIKey } from "./types";
const API_KEY_TEXT = `API Keys allow you to access Danswer APIs programmatically. Click the button below to generate a new API Key.`;
const API_KEY_TEXT = `
API Keys allow you to access Danswer APIs programmatically. Click the button below to generate a new API Key.
`;
function NewApiKeyModal({
apiKey,

View File

@@ -25,7 +25,6 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
<CardSection>
<AssistantEditor
{...values}
admin
defaultPublic={true}
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
/>

View File

@@ -51,7 +51,7 @@ export const SlackTokensForm = ({
: "Successfully created Slack Bot!",
type: "success",
});
router.push(`/admin/bots/${encodeURIComponent(botId)}`);
router.push(`/admin/bots/${botId}`);
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;

View File

@@ -142,8 +142,6 @@ export function CustomLLMProviderUpdateForm({
},
body: JSON.stringify({
...values,
// For custom llm providers, all model names are displayed
display_model_names: values.model_names,
custom_config: customConfigProcessing(values.custom_config_list),
}),
});

View File

@@ -278,6 +278,7 @@ export function LLMProviderUpdateForm({
{!(hideAdvanced && llmProviderDescriptor.name != "azure") && (
<>
<Separator />
{llmProviderDescriptor.llm_names.length > 0 ? (
<SelectorFormField
name="default_model_name"
@@ -297,6 +298,7 @@ export function LLMProviderUpdateForm({
placeholder="E.g. gpt-4"
/>
)}
{llmProviderDescriptor.deployment_name_required && (
<TextFormField
small={hideAdvanced}
@@ -305,6 +307,7 @@ export function LLMProviderUpdateForm({
placeholder="Deployment Name"
/>
)}
{!llmProviderDescriptor.single_model_supported &&
(llmProviderDescriptor.llm_names.length > 0 ? (
<SelectorFormField
@@ -341,6 +344,7 @@ export function LLMProviderUpdateForm({
/>
</>
)}
{showAdvancedOptions && (
<>
{llmProviderDescriptor.llm_names.length > 0 && (

View File

@@ -28,6 +28,7 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
<CardSection>
<AssistantEditor
{...values}
admin
defaultPublic={false}
redirectType={SuccessfulPersonaUpdateRedirectType.CHAT}
/>

View File

@@ -52,7 +52,6 @@ import {
useLayoutEffect,
useRef,
useState,
useMemo,
} from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
@@ -267,6 +266,7 @@ export function ChatPage({
availableAssistants[0];
const noAssistants = liveAssistant == null || liveAssistant == undefined;
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
useEffect(() => {
const personaDefault = getLLMProviderOverrideForPersona(
@@ -282,7 +282,7 @@ export function ChatPage({
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [liveAssistant, user?.preferences.default_model]);
}, [liveAssistant, llmProviders, user?.preferences.default_model]);
const stopGenerating = () => {
const currentSession = currentSessionId();
@@ -2007,7 +2007,7 @@ export function ChatPage({
{...getRootProps()}
>
<div
className={`w-full h-full flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
className={`w-full h-full flex flex-col overflow-y-auto include-scrollbar overflow-x-hidden relative`}
ref={scrollableDivRef}
>
{/* ChatBanner is a custom banner that displays a admin-specified message at

View File

@@ -23,22 +23,16 @@ import { useRouter } from "next/navigation";
import { CHAT_SESSION_ID_KEY } from "@/lib/drag/constants";
import Cookies from "js-cookie";
import { Popover } from "@/components/popover/Popover";
import { ChatSession } from "../interfaces";
const FolderItem = ({
folder,
currentChatId,
isInitiallyExpanded,
initiallySelected,
showShareModal,
showDeleteModal,
}: {
folder: Folder;
currentChatId?: string;
isInitiallyExpanded: boolean;
initiallySelected: boolean;
showShareModal: ((chatSession: ChatSession) => void) | undefined;
showDeleteModal: ((chatSession: ChatSession) => void) | undefined;
}) => {
const [isExpanded, setIsExpanded] = useState<boolean>(isInitiallyExpanded);
const [isEditing, setIsEditing] = useState<boolean>(initiallySelected);
@@ -167,9 +161,6 @@ const FolderItem = ({
return a.time_created.localeCompare(b.time_created);
});
// Determine whether to show the trash can icon
const showTrashIcon = (isHovering && !isEditing) || showDeleteConfirm;
return (
<div
key={folder.folder_id}
@@ -215,60 +206,55 @@ const FolderItem = ({
{editedFolderName || folder.folder_name}
</div>
)}
<div className="flex ml-auto my-auto">
<div
onClick={handleEditFolderName}
className={`hover:bg-black/10 p-1 -m-1 rounded ${
isHovering && !isEditing
? ""
: "opacity-0 pointer-events-none"
}`}
>
<FiEdit2 size={16} />
</div>
<div className="relative">
<Popover
open={showDeleteConfirm}
onOpenChange={setShowDeleteConfirm}
content={
<div
onClick={handleDeleteClick}
className={`hover:bg-black/10 p-1 -m-1 rounded ml-2 ${
showTrashIcon ? "" : "opacity-0 pointer-events-none"
}`}
>
<FiTrash size={16} />
</div>
}
popover={
<div className="p-2 w-[225px] bg-background-100 rounded shadow-lg">
<p className="text-sm mb-2">
Are you sure you want to delete folder{" "}
<i>{folder.folder_name}</i>?
</p>
<div className="flex justify-end">
<button
onClick={confirmDelete}
className="bg-red-500 hover:bg-red-600 text-white px-2 py-1 rounded text-xs mr-2"
>
Yes
</button>
<button
onClick={cancelDelete}
className="bg-gray-300 hover:bg-gray-200 px-2 py-1 rounded text-xs"
>
No
</button>
{isHovering && !isEditing && (
<div className="flex ml-auto my-auto">
<div
onClick={handleEditFolderName}
className="hover:bg-black/10 p-1 -m-1 rounded"
>
<FiEdit2 size={16} />
</div>
<div className="relative">
<Popover
open={showDeleteConfirm}
onOpenChange={setShowDeleteConfirm}
content={
<div
onClick={handleDeleteClick}
className="hover:bg-black/10 p-1 -m-1 rounded ml-2"
>
<FiTrash size={16} />
</div>
</div>
}
side="top"
align="center"
/>
}
popover={
<div className="p-2 w-[225px] bg-background-100 rounded shadow-lg">
<p className="text-sm mb-2">
Are you sure you want to delete{" "}
<i>{folder.folder_name}</i>? All the content inside
this folder will also be deleted.
</p>
<div className="flex justify-end">
<button
onClick={confirmDelete}
className="bg-red-500 hover:bg-red-600 text-white px-2 py-1 rounded text-xs mr-2"
>
Yes
</button>
<button
onClick={cancelDelete}
className="bg-gray-300 hover:bg-gray-200 px-2 py-1 rounded text-xs"
>
No
</button>
</div>
</div>
}
side="top"
align="center"
/>
</div>
</div>
</div>
)}
{isEditing && (
<div className="flex ml-auto my-auto">
@@ -290,8 +276,6 @@ const FolderItem = ({
</div>
</div>
</BasicSelectable>
{/* Expanded Folder Content */}
{isExpanded && folders && (
<div className={"ml-2 pl-2 border-l border-border"}>
{folders.map((chatSession) => (
@@ -300,8 +284,6 @@ const FolderItem = ({
chatSession={chatSession}
isSelected={chatSession.id === currentChatId}
skipGradient={isDragOver}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
/>
))}
</div>
@@ -315,15 +297,11 @@ export const FolderList = ({
currentChatId,
openedFolders,
newFolderId,
showShareModal,
showDeleteModal,
}: {
folders: Folder[];
currentChatId?: string;
openedFolders?: { [key: number]: boolean };
newFolderId: number | null;
showShareModal: ((chatSession: ChatSession) => void) | undefined;
showDeleteModal: ((chatSession: ChatSession) => void) | undefined;
}) => {
if (folders.length === 0) {
return null;
@@ -340,8 +318,6 @@ export const FolderList = ({
isInitiallyExpanded={
openedFolders ? openedFolders[folder.folder_id] || false : false
}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
/>
))}
{folders.length == 1 && folders[0].chat_sessions.length == 0 && (

View File

@@ -76,7 +76,7 @@ export function AssistantsTab({
items={assistants.map((a) => a.id.toString())}
strategy={verticalListSortingStrategy}
>
<div className="px-4 pb-2 max-h-[500px] default-scrollbar overflow-y-scroll overflow-x-hidden my-3 grid grid-cols-1 gap-4">
<div className="px-4 pb-2 max-h-[500px] include-scrollbar overflow-y-scroll my-3 grid grid-cols-1 gap-4">
{assistants.map((assistant) => (
<DraggableAssistantCard
key={assistant.id.toString()}

View File

@@ -191,6 +191,7 @@ export function ChatSessionDisplay({
</div>
</CustomTooltip>
)}
<div>
{search ? (
showDeleteModal && (

View File

@@ -74,8 +74,6 @@ export function PagesTab({
folders={folders}
currentChatId={currentChatId}
openedFolders={openedFolders}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
/>
</div>
)}

View File

@@ -260,29 +260,26 @@
}
}
.default-scrollbar::-webkit-scrollbar {
.include-scrollbar::-webkit-scrollbar {
width: 6px;
}
.default-scrollbar::-webkit-scrollbar-track {
.include-scrollbar::-webkit-scrollbar-track {
background: #f1f1f1;
}
.default-scrollbar::-webkit-scrollbar-thumb {
.include-scrollbar::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.default-scrollbar::-webkit-scrollbar-thumb:hover {
.include-scrollbar::-webkit-scrollbar-thumb:hover {
background: #555;
}
.default-scrollbar {
.include-scrollbar {
scrollbar-width: thin;
scrollbar-color: #888 transparent;
overflow: overlay;
overflow-y: scroll;
overflow-x: hidden;
}
.inputscroll::-webkit-scrollbar-track {

View File

@@ -6,7 +6,6 @@ import {
} from "@/components/settings/lib";
import {
CUSTOM_ANALYTICS_ENABLED,
GTM_ENABLED,
SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
} from "@/lib/constants";
import { Metadata } from "next";
@@ -22,7 +21,6 @@ import { getCurrentUserSS } from "@/lib/userSS";
import CardSection from "@/components/admin/CardSection";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";
const inter = Inter({
subsets: ["latin"],
@@ -82,22 +80,6 @@ export default async function RootLayout({
}}
/>
)}
{GTM_ENABLED && (
<Script
id="google-tag-manager"
strategy="afterInteractive"
dangerouslySetInnerHTML={{
__html: `
(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','GTM-PZXS36NG');
`,
}}
/>
)}
</head>
<body className={`relative ${inter.variable} font-sans`}>
<div

View File

@@ -20,7 +20,7 @@ export function AdvancedOptionsToggle({
size="sm"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="text-xs mr-auto !p-0 text-text-950 hover:text-text-500"
className="text-xs !p-0 text-text-950 hover:text-text-500"
>
{title || "Advanced Options"}
</Button>

View File

@@ -57,7 +57,7 @@ const DropdownOption: React.FC<DropdownOptionProps> = ({
};
export function UserDropdown({ page }: { page?: pageType }) {
const { user, isCurator } = useUser();
const { user } = useUser();
const [userInfoVisible, setUserInfoVisible] = useState(false);
const userInfoRef = useRef<HTMLDivElement>(null);
const router = useRouter();
@@ -95,9 +95,7 @@ export function UserDropdown({ page }: { page?: pageType }) {
}
// Construct the current URL
const currentUrl = `${pathname}${
searchParams.toString() ? `?${searchParams.toString()}` : ""
}`;
const currentUrl = `${pathname}${searchParams.toString() ? `?${searchParams.toString()}` : ""}`;
// Encode the current URL to use as a redirect parameter
const encodedRedirect = encodeURIComponent(currentUrl);
@@ -108,7 +106,9 @@ export function UserDropdown({ page }: { page?: pageType }) {
};
const showAdminPanel = !user || user.role === UserRole.ADMIN;
const showCuratorPanel = user && isCurator;
const showCuratorPanel =
user &&
(user.role === UserRole.CURATOR || user.role === UserRole.GLOBAL_CURATOR);
const showLogout =
user && !checkUserIsNoAuthUser(user.id) && !LOGOUT_DISABLED;
@@ -244,11 +244,7 @@ export function UserDropdown({ page }: { page?: pageType }) {
setShowNotifications(true);
}}
icon={<BellIcon className="h-5 w-5 my-auto mr-2" />}
label={`Notifications ${
notifications && notifications.length > 0
? `(${notifications.length})`
: ""
}`}
label={`Notifications ${notifications && notifications.length > 0 ? `(${notifications.length})` : ""}`}
/>
{showLogout &&
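
The `currentUrl` collapse in this hunk is cosmetic, but the redirect construction it feeds is easy to get wrong. A standalone sketch of the same logic (the function name is hypothetical; the `/auth/login?next=` target is an assumption that matches the login URL asserted by the Playwright tests later in this diff):

// Hypothetical standalone version of the redirect construction shown above.
function buildLoginRedirect(pathname: string, searchParams: URLSearchParams): string {
  const currentUrl = `${pathname}${searchParams.toString() ? `?${searchParams.toString()}` : ""}`;
  // Encode so the path and query survive as a single `next` parameter value.
  return `/auth/login?next=${encodeURIComponent(currentUrl)}`;
}

// e.g. buildLoginRedirect("/search", new URLSearchParams("q=a"))
//   -> "/auth/login?next=%2Fsearch%3Fq%3Da"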

View File

@@ -211,7 +211,12 @@ export function TextFormField({
<div className={`w-full ${width}`}>
<div className="flex gap-x-2 items-center">
{!removeLabel && (
<Label className={sizeClass.label} small={small}>
<Label
className={`${
small ? "text-text-950" : "text-text-700 font-normal"
} ${sizeClass.label}`}
small={small}
>
{label}
</Label>
)}
@@ -657,10 +662,7 @@ export function SelectorFormField({
{container && (
<SelectContent
side={side}
className={`
${maxHeight ? `${maxHeight}` : "max-h-72"}
overflow-y-scroll
`}
className={maxHeight ? `max-h-[${maxHeight}]` : undefined}
container={container}
>
{options.length === 0 ? (

View File

@@ -47,7 +47,7 @@ export const AssistantsProvider: React.FC<{
const [assistants, setAssistants] = useState<Persona[]>(
initialAssistants || []
);
const { user, isLoadingUser, isAdmin, isCurator } = useUser();
const { user, isLoadingUser, isAdmin } = useUser();
const [editablePersonas, setEditablePersonas] = useState<Persona[]>([]);
const [allAssistants, setAllAssistants] = useState<Persona[]>([]);
@@ -83,7 +83,7 @@ export const AssistantsProvider: React.FC<{
useEffect(() => {
const fetchPersonas = async () => {
if (!isAdmin && !isCurator) {
if (!isAdmin) {
return;
}
@@ -101,8 +101,6 @@ export const AssistantsProvider: React.FC<{
if (allResponse.ok) {
const allPersonas = await allResponse.json();
setAllAssistants(allPersonas);
} else {
console.error("Error fetching personas:", allResponse);
}
} catch (error) {
console.error("Error fetching personas:", error);
@@ -110,7 +108,7 @@ export const AssistantsProvider: React.FC<{
};
fetchPersonas();
}, [isAdmin, isCurator]);
}, [isAdmin]);
const refreshRecentAssistants = async (currentAssistant: number) => {
const response = await fetch("/api/user/recent-assistants", {
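
Note that both versions of this hunk keep the early-return guard and the `useEffect` dependency array in lockstep (`isAdmin`/`isCurator` before, `isAdmin` alone after); changing one without the other would leave the effect reading a flag it never re-runs on. A reduced sketch of the pattern, with a hypothetical endpoint:

import { useEffect, useState } from "react";

// Reduced sketch: the guard and the dependency array must name the same flags,
// or the effect will not re-run when a flag flips after the first render.
function usePersonas(isAdmin: boolean, isCurator: boolean) {
  const [personas, setPersonas] = useState<unknown[]>([]);
  useEffect(() => {
    if (!isAdmin && !isCurator) return; // same flags as the deps below
    fetch("/api/admin/personas") // hypothetical endpoint, not from this diff
      .then((res) => (res.ok ? res.json() : []))
      .then(setPersonas)
      .catch(console.error);
  }, [isAdmin, isCurator]); // keep in lockstep with the guard above
  return personas;
}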

View File

@@ -62,11 +62,7 @@ export const LlmList: React.FC<LlmListProps> = ({
return (
<div
className={`${
scrollable
? "max-h-[200px] default-scrollbar overflow-x-hidden"
: "max-h-[300px]"
} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
className={`${scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
>
{userDefault && (
<button

View File

@@ -18,7 +18,7 @@ export default function ExceptionTraceModal({
title="Full Exception Trace"
onOutsideClick={onOutsideClick}
>
<div className="overflow-y-auto default-scrollbar overflow-x-hidden pr-3 h-full mb-6">
<div className="overflow-y-auto include-scrollbar pr-3 h-full mb-6">
<div className="mb-6">
{!copyClicked ? (
<div

View File

@@ -169,8 +169,10 @@ export const SearchResultsDisplay = ({
{documents && documents.length > 0 && (
<div className="mt-4">
<div className="font-bold flex h-12 justify-between text-emphasis border-b mb-3 pb-1 border-border text-lg">
<div className="font-bold flex justify-between text-emphasis border-b mb-3 pb-1 border-border text-lg">
<p>Results</p>
<div className="h-8 w-0 invisible">invisibel text</div>
{!DISABLE_LLM_DOC_RELEVANCE &&
(contentEnriched || searchResponse.additional_relevance) && (
<TooltipProvider delayDuration={1000}>

View File

@@ -698,7 +698,7 @@ export const SearchSection = ({
</div>
</div>
<div className="absolute default-scrollbar h-screen overflow-y-auto overflow-x-hidden left-0 w-full top-0">
<div className="absolute left-0 w-full top-0">
<FunctionalHeader
sidebarToggled={toggledSidebar}
reset={() => setQuery("")}
@@ -728,6 +728,7 @@ export const SearchSection = ({
} pt-10 relative max-w-[2000px] xl:max-w-[1430px] mx-auto`}
>
<div className="absolute z-10 mobile:px-4 mobile:max-w-searchbar-max mobile:w-[90%] top-12 desktop:left-4 hidden 2xl:block mobile:left-1/2 mobile:transform mobile:-translate-x-1/2 desktop:w-52 3xl:w-64">
{/* Remove this entire SourceSelector block
{!settings?.isMobile &&
(ccPairs.length > 0 || documentSets.length > 0) && (
<SourceSelector
@@ -738,6 +739,7 @@ export const SearchSection = ({
availableTags={tags}
/>
)}
*/}
</div>
<div className="absolute left-0 hidden 2xl:block w-52 3xl:w-64"></div>
<div className="max-w-searchbar-max w-[90%] mx-auto">
@@ -762,16 +764,10 @@ export const SearchSection = ({
</div>
)}
<div
className={`mobile:fixed mobile:left-1/2 mobile:transform mobile:-translate-x-1/2 mobile:max-w-search-bar-max mobile:w-[90%] mobile:z-100 mobile:bottom-12`}
className={`mobile:max-w-search-bar-max mobile:w-[90%] mobile:z-100`}
>
<div
className={`transition-all duration-500 ease-in-out overflow-hidden
${
firstSearch
? "opacity-100 max-h-[500px]"
: "opacity-0 max-h-0"
}`}
onTransitionEnd={handleTransitionEnd}
className={`transition-all duration-500 ease-in-out overflow-hidden opacity-100 max-h-[500px]`}
>
<div className="mt-48 mb-8 flex justify-center items-center">
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message">
@@ -802,48 +798,50 @@ export const SearchSection = ({
setDefaultOverrides(SEARCH_DEFAULT_OVERRIDES_START);
await onSearch({ agentic, offset: 0 });
}}
finalAvailableDocumentSets={finalAvailableDocumentSets}
finalAvailableSources={finalAvailableSources}
finalAvailableDocumentSets={[]}
finalAvailableSources={[]}
filterManager={filterManager}
documentSets={documentSets}
ccPairs={ccPairs}
tags={tags}
documentSets={[]}
ccPairs={[]}
tags={[]}
/>
</div>
{!firstSearch && (
<SearchAnswer
isFetching={isFetching}
dedupedQuotes={dedupedQuotes}
searchResponse={searchResponse}
setSearchAnswerExpanded={setSearchAnswerExpanded}
searchAnswerExpanded={searchAnswerExpanded}
setCurrentFeedback={setCurrentFeedback}
searchState={searchState}
/>
)}
<div className="mt-6">
{!firstSearch && (
<SearchAnswer
isFetching={isFetching}
dedupedQuotes={dedupedQuotes}
searchResponse={searchResponse}
setSearchAnswerExpanded={setSearchAnswerExpanded}
searchAnswerExpanded={searchAnswerExpanded}
setCurrentFeedback={setCurrentFeedback}
searchState={searchState}
/>
)}
{!settings?.isMobile && (
<div className="mt-6">
{!(agenticResults && isFetching) || disabledAgentic ? (
<SearchResultsDisplay
searchState={searchState}
disabledAgentic={disabledAgentic}
contentEnriched={contentEnriched}
comments={comments}
sweep={sweep}
agenticResults={
shouldUseAgenticDisplay && !disabledAgentic
}
performSweep={performSweep}
searchResponse={searchResponse}
isFetching={isFetching}
defaultOverrides={defaultOverrides}
/>
) : (
<></>
)}
</div>
)}
{!settings?.isMobile && (
<div className="mt-6">
{!(agenticResults && isFetching) || disabledAgentic ? (
<SearchResultsDisplay
searchState={searchState}
disabledAgentic={disabledAgentic}
contentEnriched={contentEnriched}
comments={comments}
sweep={sweep}
agenticResults={
shouldUseAgenticDisplay && !disabledAgentic
}
performSweep={performSweep}
searchResponse={searchResponse}
isFetching={isFetching}
defaultOverrides={defaultOverrides}
/>
) : (
<></>
)}
</div>
)}
</div>
</div>
</div>
}

View File

@@ -67,10 +67,7 @@ export function UserProvider({
isLoadingUser,
refreshUser,
isAdmin: upToDateUser?.role === UserRole.ADMIN,
// Curator status applies for either global or basic curator
isCurator:
upToDateUser?.role === UserRole.CURATOR ||
upToDateUser?.role === UserRole.GLOBAL_CURATOR,
isCurator: upToDateUser?.role === UserRole.CURATOR,
isCloudSuperuser: upToDateUser?.is_cloud_superuser ?? false,
}}
>

View File

@@ -22,8 +22,7 @@ export async function fetchAssistantData(): Promise<AssistantData> {
// Fetch core assistants data first
const [assistants, assistantsFetchError] = await fetchAssistantsSS();
if (assistantsFetchError) {
// This is not a critical error and occurs when the user is not logged in
console.warn(`Failed to fetch assistants - ${assistantsFetchError}`);
console.error(`Failed to fetch assistants - ${assistantsFetchError}`);
return defaultState;
}

View File

@@ -36,10 +36,8 @@ export const SIDEBAR_WIDTH = `w-[350px]`;
export const LOGOUT_DISABLED =
process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";
// Default sidebar open is true if the environment variable is not set
export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true" ??
true;
process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true";
export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";
@@ -62,9 +60,6 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY
? true
: false;
export const GTM_ENABLED =
process.env.NEXT_PUBLIC_GTM_ENABLED?.toLowerCase() === "true";
export const DISABLE_LLM_DOC_RELEVANCE =
process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";
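
The removed `?? true` on `NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN` was dead code: `===` always evaluates to a boolean, which is never nullish, so the fallback could never fire and the flag was already defaulting to false whenever the variable was unset, despite the comment above it. A sketch of a parser that actually honors a default (`envFlag` is a hypothetical helper, not part of this diff):

// Hypothetical helper: only an unset variable falls back to the default.
function envFlag(raw: string | undefined, defaultValue: boolean): boolean {
  if (raw === undefined) {
    return defaultValue; // the only path where a default can apply
  }
  return raw.toLowerCase() === "true";
}

// `x?.toLowerCase() === "true" ?? true` can never fall back:
// `===` yields true or false, and `??` only triggers on null/undefined.
const sidebarOpen = envFlag(process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN, true);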

View File

@@ -174,6 +174,7 @@ export function useLlmOverride(
modelName: "",
}
);
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? destructureValue(currentChatSession.current_alternate_model)

View File

@@ -1,14 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Connectors - Add Connector",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/add-connector");
await expect(page.locator("h1.text-3xl")).toHaveText("Add Connector");
await expect(page.locator("h1.text-lg").nth(0)).toHaveText(/^Storage/);
}
);

View File

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - User Management - API Keys",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/api-key");
await expect(page.locator("h1.text-3xl")).toHaveText("API Keys");
await expect(page.locator("p.text-sm")).toHaveText(
/^API Keys allow you to access Danswer APIs programmatically/
);
await expect(
page.getByRole("button", { name: "Create API Key" })
).toHaveCount(1);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Custom Assistants - Assistants",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/assistants");
await expect(page.locator("h1.text-3xl")).toHaveText("Assistants");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
/^Assistants are a way to build/
);
}
);

View File

@@ -1,24 +0,0 @@
// dependency for all admin user tests
import { test as setup, expect } from "@playwright/test";
import { TEST_CREDENTIALS } from "./constants";
setup("authenticate", async ({ page }) => {
const { email, password } = TEST_CREDENTIALS;
await page.goto("http://localhost:3000/search");
await page.waitForURL("http://localhost:3000/auth/login?next=%2Fsearch");
await expect(page).toHaveTitle("Danswer");
await page.fill("#email", email);
await page.fill("#password", password);
// Click the login button
await page.click('button[type="submit"]');
await page.waitForURL("http://localhost:3000/search");
await page.context().storageState({ path: "admin_auth.json" });
});
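
This deleted setup spec logged in once and persisted the session to `admin_auth.json` for the `@admin` tests to reuse. The config wiring that consumes that file is not shown in this diff; a sketch under standard Playwright project conventions:

// playwright.config.ts (assumed wiring, not part of this diff)
import { defineConfig } from "@playwright/test";

export default defineConfig({
  projects: [
    // Runs auth.setup.ts once and persists the session cookies.
    { name: "setup", testMatch: /auth\.setup\.ts/ },
    // Admin specs start already logged in via the saved storage state.
    {
      name: "admin",
      use: { storageState: "admin_auth.json" },
      dependencies: ["setup"],
      grep: /@admin/,
    },
  ],
});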

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Custom Assistants - Slack Bots",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/bots");
await expect(page.locator("h1.text-3xl")).toHaveText("Slack Bots");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
/^Setup Slack bots that connect to Danswer./
);
}
);

View File

@@ -1,18 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Configuration - Document Processing",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto(
"http://localhost:3000/admin/configuration/document-processing"
);
await expect(page.locator("h1.text-3xl")).toHaveText("Document Processing");
await expect(page.locator("h3.text-2xl")).toHaveText(
"Process with Unstructured API"
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Configuration - LLM",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/configuration/llm");
await expect(page.locator("h1.text-3xl")).toHaveText("LLM Setup");
await expect(page.locator("h1.text-lg").nth(0)).toHaveText(
"Enabled LLM Providers"
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Configuration - Search Settings",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/configuration/search");
await expect(page.locator("h1.text-3xl")).toHaveText("Search Settings");
await expect(page.locator("h1.text-lg").nth(0)).toHaveText(
"Embedding Model"
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Document Management - Feedback",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/documents/explorer");
await expect(page.locator("h1.text-3xl")).toHaveText("Document Explorer");
await expect(page.locator("div.flex.text-emphasis.mt-3")).toHaveText(
"Search for a document above to modify its boost or hide it from searches."
);
}
);

View File

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Document Management - Feedback",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/documents/feedback");
await expect(page.locator("h1.text-3xl")).toHaveText("Document Feedback");
await expect(page.locator("h1.text-lg").nth(0)).toHaveText(
"Most Liked Documents"
);
await expect(page.locator("h1.text-lg").nth(1)).toHaveText(
"Most Disliked Documents"
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Document Management - Document Sets",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/documents/sets");
await expect(page.locator("h1.text-3xl")).toHaveText("Document Sets");
await expect(page.locator("p.text-sm")).toHaveText(
/^Document Sets allow you to group logically connected documents into a single bundle./
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - User Management - Groups",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/groups");
await expect(page.locator("h1.text-3xl")).toHaveText("Manage User Groups");
await expect(
page.getByRole("button", { name: "Create New User Group" })
).toHaveCount(1);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Connectors - Existing Connectors",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/indexing/status");
await expect(page.locator("h1.text-3xl")).toHaveText("Existing Connectors");
await expect(page.locator("p.text-sm")).toHaveText(
/^It looks like you don't have any connectors setup yet./
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Performance - Custom Analytics",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/performance/custom-analytics");
await expect(page.locator("h1.text-3xl")).toHaveText("Custom Analytics");
await expect(page.locator("div.font-medium").nth(0)).toHaveText(
"Custom Analytics is not enabled."
);
}
);

View File

@@ -1,22 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test.describe("Admin Performance Query History", () => {
// Ignores the diff for elements targeted by the specified list of selectors
// exclude button since they change based on the date
test.use({ ignoreSelectors: ["button"] });
test(
"Admin - Performance - Query History",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/performance/query-history");
await expect(page.locator("h1.text-3xl")).toHaveText("Query History");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
"Feedback Type"
);
}
);
});

View File

@@ -1,20 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test.describe("Admin Performance Usage", () => {
// Ignores the diff for elements targeted by the specified list of selectors
// exclude button and svg since they change based on the date
test.use({ ignoreSelectors: ["button", "svg"] });
test(
"Admin - Performance - Usage Statistics",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto("http://localhost:3000/admin/performance/usage");
await expect(page.locator("h1.text-3xl")).toHaveText("Usage Statistics");
await expect(page.locator("h1.text-lg").nth(0)).toHaveText("Usage");
await expect(page.locator("h1.text-lg").nth(1)).toHaveText("Feedback");
}
);
});

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Custom Assistants - Prompt Library",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/prompt-library");
await expect(page.locator("h1.text-3xl")).toHaveText("Prompt Library");
await expect(page.locator("p.text-sm")).toHaveText(
/^Create prompts that can be accessed/
);
}
);

View File

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Settings - Workspace Settings",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/settings");
await expect(page.locator("h1.text-3xl")).toHaveText("Workspace Settings");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
/^Manage general Danswer settings applicable to all users in the workspace./
);
await expect(
page.getByRole("button", { name: "Set Retention Limit" })
).toHaveCount(1);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Custom Assistants - Standard Answers",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/standard-answer");
await expect(page.locator("h1.text-3xl")).toHaveText("Standard Answers");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
/^Manage the standard answers for pre-defined questions./
);
}
);

View File

@@ -1,22 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - User Management - Token Rate Limits",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/token-rate-limits");
await expect(page.locator("h1.text-3xl")).toHaveText("Token Rate Limits");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
/^Token rate limits enable you control how many tokens can be spent in a given time period./
);
await expect(
page.getByRole("button", { name: "Create a Token Rate Limit" })
).toHaveCount(1);
await expect(page.locator("h1.text-lg")).toHaveText(
"Global Token Rate Limits"
);
}
);

View File

@@ -1,16 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Custom Assistants - Tools",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/tools");
await expect(page.locator("h1.text-3xl")).toHaveText("Tools");
await expect(page.locator("p.text-sm")).toHaveText(
"Tools allow assistants to retrieve information or take actions."
);
}
);

View File

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - User Management - Groups",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/users");
await expect(page.locator("h1.text-3xl")).toHaveText("Manage Users");
await expect(page.locator("div.font-bold").nth(0)).toHaveText(
"Invited Users"
);
await expect(page.locator("div.font-bold").nth(1)).toHaveText(
"Current Users"
);
}
);

View File

@@ -1,18 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - Performance - Whitelabeling",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/whitelabeling");
await expect(page.locator("h1.text-3xl")).toHaveText("Whitelabeling");
await expect(page.locator("div.block").nth(0)).toHaveText(
"Application Name"
);
await expect(page.locator("div.block").nth(1)).toHaveText("Custom Logo");
await expect(page.getByRole("button", { name: "Update" })).toHaveCount(1);
}
);

View File

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Chat",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/chat");
await expect(page.locator("div.text-2xl").nth(0)).toHaveText("General");
await expect(page.getByRole("button", { name: "Search S" })).toHaveClass(
/text-text-application-untoggled/
);
await expect(page.getByRole("button", { name: "Chat D" })).toHaveClass(
/text-text-application-toggled/
);
}
);

View File

@@ -1,5 +0,0 @@
// constants.js
export const TEST_CREDENTIALS = {
email: "admin_user@test.com",
password: "test",
};

View File

@@ -1,31 +1,27 @@
// Add this line
import { test, expect, takeSnapshot } from "@chromatic-com/playwright";
import { TEST_CREDENTIALS } from "./constants";
// Then use as normal 👇
test(
"Homepage",
{
tag: "@guest",
},
async ({ page }, testInfo) => {
// Test redirect to login, and redirect to search after login
const { email, password } = TEST_CREDENTIALS;
test("Homepage", async ({ page }, testInfo) => {
// Test redirect to login, and redirect to search after login
await page.goto("http://localhost:3000/search");
// move these into a constants file or test fixture soon
let email = "admin_user@test.com";
let password = "test";
await page.waitForURL("http://localhost:3000/auth/login?next=%2Fsearch");
await page.goto("http://localhost:3000/search");
await expect(page).toHaveTitle("Danswer");
await page.waitForURL("http://localhost:3000/auth/login?next=%2Fsearch");
await takeSnapshot(page, "Before login", testInfo);
await expect(page).toHaveTitle("Danswer");
await page.fill("#email", email);
await page.fill("#password", password);
await takeSnapshot(page, "Before login", testInfo);
// Click the login button
await page.click('button[type="submit"]');
await page.fill("#email", email);
await page.fill("#password", password);
await page.waitForURL("http://localhost:3000/search");
}
);
// Click the login button
await page.click('button[type="submit"]');
await page.waitForURL("http://localhost:3000/search");
});

View File

@@ -1,19 +0,0 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Search",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/search");
await expect(page.locator("div.text-3xl")).toHaveText("Unlock Knowledge");
await expect(page.getByRole("button", { name: "Search S" })).toHaveClass(
/text-text-application-toggled/
);
await expect(page.getByRole("button", { name: "Chat D" })).toHaveClass(
/text-text-application-untoggled/
);
}
);