Compare commits

...

28 Commits

Author SHA1 Message Date
pablodanswer
f940e6a12d nit 2024-10-20 23:10:21 -07:00
pablodanswer
67ae58ac06 nit 2024-10-20 23:09:04 -07:00
pablodanswer
a6b7159c17 k: 2024-10-20 22:54:19 -07:00
pablodanswer
c62860fa44 k 2024-10-20 18:16:49 -07:00
pablodanswer
83f83cb3eb kk 2024-10-20 17:30:04 -07:00
pablodanswer
0599a6cf65 initial bones 2024-10-20 15:30:45 -07:00
pablodanswer
8f67f1715c minor typing 2024-10-20 14:48:19 -07:00
pablodanswer
3b365509e2 k 2024-10-20 14:41:12 -07:00
pablodanswer
022cbdfccf robustified cloud auth type 2024-10-20 14:28:22 -07:00
pablodanswer
ebec6f6b10 k 2024-10-20 13:43:08 -07:00
pablodanswer
1cad9c7b3d add cloud auth type 2024-10-20 13:43:08 -07:00
pablodanswer
b4e975013c k 2024-10-20 13:42:38 -07:00
pablodanswer
dd26f92206 nit 2024-10-20 13:41:41 -07:00
pablodanswer
4d00ec45ad remove comments + notice logs 2024-10-20 13:34:13 -07:00
pablodanswer
1a81c67a67 k 2024-10-20 13:22:00 -07:00
pablodanswer
04f965e656 k 2024-10-20 11:52:24 -07:00
pablodanswer
277d37e0ee fix 2024-10-20 11:45:00 -07:00
pablodanswer
3cd260131b k 2024-10-20 10:16:19 -07:00
pablodanswer
ad21ee0e9a fix mysterious syncing issue! 2024-10-19 19:26:57 -07:00
pablodanswer
c7dc0e9af0 k 2024-10-19 19:15:55 -07:00
pablodanswer
75c5de802b ensure tenant id passed 2024-10-19 19:15:55 -07:00
pablodanswer
c39f590d0d k 2024-10-19 19:15:55 -07:00
pablodanswer
82a9fda846 add types 2024-10-19 19:15:55 -07:00
pablodanswer
842d4ab2a8 k 2024-10-19 19:15:55 -07:00
pablodanswer
cddcec4ea4 k 2024-10-19 19:15:55 -07:00
pablodanswer
09dd7b424c validated workaround for flush + reset 2024-10-19 19:15:55 -07:00
pablodanswer
a2fd8d5e0a add some more multi tenancy 2024-10-19 19:15:55 -07:00
pablodanswer
802dc00f78 k 2024-10-19 19:15:55 -07:00
37 changed files with 957 additions and 291 deletions

View File

@@ -130,7 +130,7 @@ jobs:
done
echo "Finished waiting for service."
- name: Run integration tests
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
@@ -145,7 +145,8 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
danswer/danswer-integration:test \
/app/tests/integration/tests
continue-on-error: true
id: run_tests
@@ -158,6 +159,56 @@ jobs:
echo "All integration tests passed successfully."
fi
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# Start containers for multi-tenant tests
- name: Start Docker containers for multi-tenant tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=basic \ # Adjust as needed
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
danswer/danswer-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
id: run_tests
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Save Docker logs
if: success() || failure()
run: |

View File

@@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
@@ -132,7 +133,9 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
return not DISABLE_VERIFICATION and (
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
)
def verify_email_is_invited(email: str) -> None:
@@ -233,35 +236,60 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
if hasattr(user_create, "role"):
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user
async def on_after_login(
self,
@@ -319,7 +347,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,

View File

@@ -4,7 +4,6 @@ import time
from datetime import timedelta
from typing import Any
import redis
import sentry_sdk
from celery import bootsteps # type: ignore
from celery import Celery
@@ -79,6 +78,7 @@ def on_task_prerun(
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
@@ -91,7 +91,7 @@ def on_task_postrun(
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
@@ -110,7 +110,17 @@ def on_task_postrun(
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
if state not in READY_STATES:
return
@@ -118,7 +128,7 @@ def on_task_postrun(
if not task_id:
return
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -171,7 +181,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
# (set at the command line)'
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@@ -182,166 +193,155 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
tenant_ids = get_all_tenant_ids()
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
for tenant_id in tenant_ids:
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
if search_settings.provider_type is None:
logger.notice(
"Running a first inference to warm up embedding model"
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
r = get_redis_client()
if not hasattr(sender, "primary_worker_locks"):
sender.primary_worker_locks = {}
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
tenant_ids = get_all_tenant_ids()
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...")
time_start = time.monotonic()
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time.monotonic()
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.notice("Redis: Readiness check starting.")
while True:
# Log all the locks in Redis
all_locks = r.keys("*")
logger.notice(f"Current Redis locks: {all_locks}")
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
logger.error(msg)
raise WorkerShutdown(msg)
if time_elapsed > WAIT_LIMIT:
msg = (
"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return # Exit the function for secondary workers
time.sleep(WAIT_INTERVAL)
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
logger.info("Wait for primary worker completed successfully. Continuing...")
return
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as the primary celery worker.")
time_start = time.monotonic()
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_lock = lock
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
@@ -367,14 +367,15 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
if lock.owned():
lock.release()
sender.primary_worker_lock = None
for tenant_id, lock in sender.primary_worker_locks.items():
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
if lock.owned():
lock.release()
sender.primary_worker_locks = {}
class CeleryTaskPlainFormatter(PlainFormatter):
@@ -449,17 +450,18 @@ def on_setup_logging(
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# it's unclear to me whether using the hub's timer or the bootstep timer is better
# Requires the Hub component
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
@@ -478,42 +480,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
def run_periodic_task(self, worker: Any) -> None:
try:
if not worker.primary_worker_lock:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
r = get_redis_client()
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
lock: redis.lock.Lock = worker.primary_worker_lock
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
worker.primary_worker_lock = lock
except Exception:
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
@@ -583,14 +601,14 @@ tasks_to_schedule = [
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"args": (tenant_id,), # Must pass tenant_id as an argument
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules

View File

@@ -31,7 +31,10 @@ logger = setup_logger()
def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -44,7 +47,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
@@ -54,9 +57,14 @@ def _get_deletion_status(
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None

View File

@@ -23,8 +23,8 @@ from danswer.redis.redis_pool import get_redis_client
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
r = get_redis_client()
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,

View File

@@ -51,10 +51,10 @@ logger = setup_logger()
name="check_for_indexing",
soft_time_limit=300,
)
def check_for_indexing(tenant_id: str | None) -> int | None:
def check_for_indexing(*, tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -64,7 +64,10 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -367,7 +370,7 @@ def connector_indexing_task(
attempt = None
n_final_progress = 0
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)

View File

@@ -38,8 +38,8 @@ logger = setup_logger()
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_pruning(tenant_id: str | None) -> None:
r = get_redis_client()
def check_for_pruning(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -204,7 +204,7 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)

View File

@@ -59,6 +59,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -66,6 +67,8 @@ from danswer.utils.variable_functionality import (
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@@ -74,11 +77,11 @@ from danswer.utils.variable_functionality import noop_fallback
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task(tenant_id: str | None) -> None:
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -640,7 +643,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False
"""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,

View File

@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE

View File

@@ -160,6 +160,9 @@ class AuthType(str, Enum):
OIDC = "oidc"
SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum):
CHAT = "Chat"

View File

@@ -41,7 +41,7 @@ class ConfluenceRateLimitError(Exception):
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client()
# r = get_redis_client(tenant_id=tenant_id)
# for attempt in range(max_retries):
# try:

View File

@@ -241,8 +241,7 @@ def create_credential(
curator_public=credential_data.curator_public,
)
db_session.add(credential)
db_session.flush() # This ensures the credential gets an ID
db_session.flush() # This ensures the credential gets an IDcredentials
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
credential_id=credential.id,

View File

@@ -295,30 +295,32 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection without starting a transaction
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Execute SET search_path outside of any transaction
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path TO "{tenant_id}"')
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Proceed to create a session using the connection
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
@@ -332,6 +334,18 @@ def get_session_with_tenant(
cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
def get_session_generator_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:

View File

@@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()
tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager
def get_session(self) -> Iterator[Session]:

View File

@@ -269,7 +269,7 @@ def get_application() -> FastAPI:
# Server logs this during auth setup verification step
pass
elif AUTH_TYPE == AuthType.BASIC:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_auth_router(auth_backend),
@@ -301,7 +301,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,

View File

@@ -1,4 +1,7 @@
import functools
import threading
from collections.abc import Callable
from typing import Any
from typing import Optional
import redis
@@ -14,6 +17,72 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
from danswer.utils.logger import setup_logger
logger = setup_logger()
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.tenant_id: str = tenant_id
def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
prefix: str = f"{self.tenant_id}:"
if isinstance(key, str):
if key.startswith(prefix):
return key
else:
return prefix + key
elif isinstance(key, bytes):
prefix_bytes = prefix.encode()
if key.startswith(prefix_bytes):
return key
else:
return prefix_bytes + key
elif isinstance(key, memoryview):
key_bytes = key.tobytes()
prefix_bytes = prefix.encode()
if key_bytes.startswith(prefix_bytes):
return key
else:
return memoryview(prefix_bytes + key_bytes)
else:
raise TypeError(f"Unsupported key type: {type(key)}")
def _prefix_method(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if "name" in kwargs:
kwargs["name"] = self._prefixed(kwargs["name"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
return method(*args, **kwargs)
return wrapper
def __getattribute__(self, item: str) -> Any:
original_attr = super().__getattribute__(item)
methods_to_wrap = [
"lock",
"unlock",
"get",
"set",
"delete",
"exists",
"incrby",
"hset",
"hget",
"getset",
"scan_iter",
"owned",
"reacquire",
"create_lock",
"startswith",
] # Add all methods that need prefixing
if item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr
class RedisPool:
@@ -32,8 +101,10 @@ class RedisPool:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def get_client(self) -> Redis:
return redis.Redis(connection_pool=self._pool)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
@staticmethod
def create_pool(
@@ -84,8 +155,8 @@ class RedisPool:
redis_pool = RedisPool()
def get_redis_client() -> Redis:
return redis_pool.get_client()
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
# # Usage example

View File

@@ -9,6 +9,7 @@ from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_documents_by_cc_pair
from danswer.db.document import get_ingestion_documents
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
@@ -67,6 +68,7 @@ def upsert_ingestion_doc(
doc_info: IngestionDocument,
_: User | None = Depends(api_key_dep),
db_session: Session = Depends(get_session),
tenant_id=Depends(get_current_tenant_id),
) -> IngestionResult:
doc_info.document.from_ingestion_api = True
@@ -101,6 +103,7 @@ def upsert_ingestion_doc(
document_index=curr_doc_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
)
new_doc, __chunk_count = indexing_pipeline(
@@ -134,6 +137,7 @@ def upsert_ingestion_doc(
document_index=sec_doc_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
)
sec_ind_pipeline(

View File

@@ -24,6 +24,7 @@ from danswer.db.connector_credential_pair import (
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@@ -90,6 +91,7 @@ def get_cc_pair_full_info(
cc_pair_id: int,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False
@@ -136,6 +138,7 @@ def get_cc_pair_full_info(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
tenant_id=tenant_id,
),
num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user,
@@ -231,6 +234,7 @@ def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately"""
@@ -246,7 +250,7 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(db_session, r):
raise HTTPException(

View File

@@ -482,10 +482,11 @@ def get_connector_indexing_status(
get_editable: bool = Query(
False, description="If true, return editable document sets"
),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
@@ -606,6 +607,7 @@ def get_connector_indexing_status(
connector_id=connector.id,
credential_id=credential.id,
db_session=db_session,
tenant_id=tenant_id,
),
is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair,
@@ -683,15 +685,18 @@ def create_connector_with_mock_credential(
connector_response = create_connector(
db_session=db_session, connector_data=connector_data
)
mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source
)
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
response = add_credential_to_connector(
db_session=db_session,
user=user,
@@ -775,7 +780,7 @@ def connector_run_once(
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids

View File

@@ -42,7 +42,9 @@ global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
def sync_external_doc_permissions_task(
cc_pair_id: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@@ -50,7 +52,7 @@ def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None
cc_pair_id: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@@ -59,7 +61,7 @@ def sync_external_group_permissions_task(
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(
retention_limit_days: int, tenant_id: str | None
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -72,7 +74,7 @@ def perform_ttl_management_task(
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
"""Runs periodically to sync external permissions"""
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -89,7 +91,7 @@ def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
"""Runs periodically to sync external group permissions"""
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -106,7 +108,7 @@ def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(tenant_id: str | None) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
@@ -130,7 +132,7 @@ def check_ttl_management_task(tenant_id: str | None) -> None:
name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
@@ -179,7 +181,7 @@ for tenant_id in tenant_ids:
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"args": (tenant_id,), # Must pass tenant_id as an argument
"kwargs": {"tenant_id": tenant_id}, # Pass tenant_id as a keyword argument
}
# Include any existing beat schedules

View File

@@ -22,8 +22,9 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
) -> Response:
try:
logger.info(f"Request route: {request.url.path}")
logger.info(f"Request cookies: {request.cookies}")
if not MULTI_TENANT:
logger.info("SETITNG TO DEFAULt")
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("tenant_details")

View File

@@ -83,4 +83,5 @@ COPY ./tests/integration /app/tests/integration
ENV PYTHONPATH=/app
CMD ["pytest", "-s", "/app/tests/integration"]
ENTRYPOINT ["pytest", "-s"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]

View File

@@ -23,7 +23,7 @@ from tests.integration.common_utils.test_models import StreamedResponse
class ChatSessionManager:
@staticmethod
def create(
persona_id: int = -1,
persona_id: int = 0,
description: str = "Test chat session",
user_performing_action: DATestUser | None = None,
) -> DATestChatSession:

View File

@@ -32,6 +32,7 @@ class CredentialManager:
"curator_public": curator_public,
"groups": groups or [],
}
response = requests.post(
url=f"{API_SERVER_URL}/manage/credential",
json=credential_request,

View File

@@ -0,0 +1,82 @@
from datetime import datetime
from datetime import timedelta
import jwt
import requests
from danswer.server.manage.models import AllUsersResponse
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
def generate_auth_token() -> str:
payload = {
"iss": "control_plane",
"exp": datetime.utcnow() + timedelta(minutes=5),
"iat": datetime.utcnow(),
"scope": "tenant:create",
}
token = jwt.encode(payload, "", algorithm="HS256")
return token
class TenantManager:
@staticmethod
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
}
token = generate_auth_token()
headers = {
"Authorization": f"Bearer {token}",
"X-API-KEY": "",
"Content-Type": "application/json",
}
response = requests.post(
url=f"{API_SERVER_URL}/tenants/create",
json=body,
headers=headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all_users(
user_performing_action: DATestUser | None = None,
) -> AllUsersResponse:
response = requests.get(
url=f"{API_SERVER_URL}/manage/users",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
data = response.json()
return AllUsersResponse(
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
accepted_pages=data["accepted_pages"],
invited_pages=data["invited_pages"],
)
@staticmethod
def verify_user_in_tenant(
user: DATestUser, user_performing_action: DATestUser | None = None
) -> None:
all_users = TenantManager.get_all_users(user_performing_action)
for accepted_user in all_users.accepted:
if accepted_user.email == user.email and accepted_user.id == user.id:
return
raise ValueError(f"User {user.email} not found in tenant")

View File

@@ -65,15 +65,23 @@ class UserManager:
data=data,
headers=headers,
)
response.raise_for_status()
result_cookie = next(iter(response.cookies), None)
if not result_cookie:
response.raise_for_status()
cookies = response.cookies.get_dict()
session_cookie = cookies.get("fastapiusersauth")
tenant_details_cookie = cookies.get("tenant_details")
if not session_cookie:
raise Exception("Failed to login")
print(f"Logged in as {test_user.email}")
cookie = f"{result_cookie.name}={result_cookie.value}"
test_user.headers["Cookie"] = cookie
# Set both cookies in the headers
test_user.headers["Cookie"] = (
f"fastapiusersauth={session_cookie}; "
f"tenant_details={tenant_details_cookie}"
)
return test_user
@staticmethod

View File

@@ -1,17 +1,20 @@
import logging
import time
from types import SimpleNamespace
import psycopg2
import requests
from alembic import command
from alembic.config import Config
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SYNC_DB_API
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
@@ -26,7 +29,11 @@ logger = setup_logger()
def _run_migrations(
database_url: str, direction: str = "upgrade", revision: str = "head"
database_url: str,
config_name: str,
direction: str = "upgrade",
revision: str = "head",
schema: str = "public",
) -> None:
# hide info logs emitted during migration
logging.getLogger("alembic").setLevel(logging.CRITICAL)
@@ -35,6 +42,10 @@ def _run_migrations(
alembic_cfg = Config("alembic.ini")
alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
alembic_cfg.attributes["configure_logger"] = False
alembic_cfg.config_ini_section = config_name
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema}"] # type: ignore
# Set the SQLAlchemy URL in the Alembic configuration
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
@@ -52,7 +63,9 @@ def _run_migrations(
logging.getLogger("alembic").setLevel(logging.INFO)
def reset_postgres(database: str = "postgres") -> None:
def reset_postgres(
database: str = "postgres", config_name: str = "alembic", setup_danswer: bool = True
) -> None:
"""Reset the Postgres database."""
# NOTE: need to delete all rows to allow migrations to be rolled back
@@ -111,14 +124,18 @@ def reset_postgres(database: str = "postgres") -> None:
)
_run_migrations(
conn_str,
config_name,
direction="downgrade",
revision="base",
)
_run_migrations(
conn_str,
config_name,
direction="upgrade",
revision="head",
)
if setup_danswer:
return
# do the same thing as we do on API server startup
with get_session_context_manager() as db_session:
@@ -127,6 +144,7 @@ def reset_postgres(database: str = "postgres") -> None:
def reset_vespa() -> None:
"""Wipe all data from the Vespa index."""
with get_session_context_manager() as db_session:
# swap to the correct default model
check_index_swap(db_session)
@@ -166,10 +184,98 @@ def reset_vespa() -> None:
time.sleep(5)
def reset_postgres_multitenant() -> None:
"""Reset the Postgres database for all tenants in a multitenant setup."""
conn = psycopg2.connect(
dbname="postgres",
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
)
conn.autocommit = True
cur = conn.cursor()
# Get all tenant schemas
cur.execute(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name LIKE 'tenant_%'
"""
)
tenant_schemas = cur.fetchall()
# Drop all tenant schemas
for schema in tenant_schemas:
schema_name = schema[0]
cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')
cur.close()
conn.close()
reset_postgres(config_name="schema_private")
def reset_vespa_multitenant() -> None:
"""Wipe all data from the Vespa index for all tenants."""
for tenant_id in get_all_tenant_ids():
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# swap to the correct default model for each tenant
check_index_swap(db_session)
search_settings = get_current_search_settings(db_session)
index_name = search_settings.index_name
success = setup_vespa(
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError(
f"Could not connect to Vespa for tenant {tenant_id} within the specified timeout."
)
for _ in range(5):
try:
continuation = None
should_continue = True
while should_continue:
params = {"selection": "true", "cluster": "danswer_index"}
if continuation:
params = {**params, "continuation": continuation}
response = requests.delete(
DOCUMENT_ID_ENDPOINT.format(index_name=index_name),
params=params,
)
response.raise_for_status()
response_json = response.json()
continuation = response_json.get("continuation")
should_continue = bool(continuation)
break
except Exception as e:
print(f"Error deleting documents for tenant {tenant_id}: {e}")
time.sleep(5)
def reset_all() -> None:
"""Reset both Postgres and Vespa."""
logger.info("Resetting Postgres...")
reset_postgres()
logger.info("Resetting Vespa...")
reset_vespa()
def reset_all_multitenant() -> None:
"""Reset both Postgres and Vespa for all tenants."""
logger.info("Resetting Postgres for all tenants...")
reset_postgres_multitenant()
logger.info("Resetting Vespa for all tenants...")
reset_vespa_multitenant()
logger.info("Finished resetting all.")

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
from tests.integration.common_utils.vespa import vespa_fixture
@@ -44,3 +45,8 @@ def vespa_client(db_session: Session) -> vespa_fixture:
@pytest.fixture
def reset() -> None:
reset_all()
@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()

View File

@@ -0,0 +1,154 @@
from danswer.db.models import UserRole
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "test1@test.com")
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
assert UserManager.verify_role(test_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
TenantManager.create("tenant_dev2", "test2@test.com")
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
assert UserManager.verify_role(test_user2, UserRole.ADMIN)
# Create connectors for Tenant 1
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=test_user1,
)
api_key_1: DATestAPIKey = APIKeyManager.create(
user_performing_action=test_user1,
)
api_key_1.headers.update(test_user1.headers)
LLMProviderManager.create(user_performing_action=test_user1)
# Seed documents for Tenant 1
cc_pair_1.documents = []
doc1_tenant1 = DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content="Tenant 1 Document Content",
api_key=api_key_1,
)
doc2_tenant1 = DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content="Tenant 1 Document Content",
api_key=api_key_1,
)
cc_pair_1.documents.extend([doc1_tenant1, doc2_tenant1])
# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=test_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=test_user2,
)
api_key_2.headers.update(test_user2.headers)
LLMProviderManager.create(user_performing_action=test_user2)
# Seed documents for Tenant 2
cc_pair_2.documents = []
doc1_tenant2 = DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_2,
content="Tenant 2 Document Content",
api_key=api_key_2,
)
doc2_tenant2 = DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_2,
content="Tenant 2 Document Content",
api_key=api_key_2,
)
cc_pair_2.documents.extend([doc1_tenant2, doc2_tenant2])
tenant1_doc_ids = {doc1_tenant1.id, doc2_tenant1.id}
tenant2_doc_ids = {doc1_tenant2.id, doc2_tenant2.id}
# Create chat sessions for each user
chat_session1: DATestChatSession = ChatSessionManager.create(
user_performing_action=test_user1
)
chat_session2: DATestChatSession = ChatSessionManager.create(
user_performing_action=test_user2
)
# User 1 sends a message and gets a response
response1 = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
message="What is in Tenant 1's documents?",
user_performing_action=test_user1,
)
# Assert that the search tool was used
assert response1.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response1.tool_result}
assert tenant1_doc_ids.issubset(
response_doc_ids
), "Not all Tenant 1 document IDs are in the response"
assert not response_doc_ids.intersection(
tenant2_doc_ids
), "Tenant 2 document IDs should not be in the response"
# Assert that the contents are correct
for doc in response1.tool_result:
assert doc["content"] == "Tenant 1 Document Content"
# User 2 sends a message and gets a response
response2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
message="What is in Tenant 2's documents?",
user_performing_action=test_user2,
)
# Assert that the search tool was used
assert response2.tool_name == "run_search"
# Assert that the tool_result contains Tenant 2's documents
response_doc_ids = {doc["document_id"] for doc in response2.tool_result}
assert tenant2_doc_ids.issubset(
response_doc_ids
), "Not all Tenant 2 document IDs are in the response"
assert not response_doc_ids.intersection(
tenant1_doc_ids
), "Tenant 1 document IDs should not be in the response"
# Assert that the contents are correct
for doc in response2.tool_result:
assert doc["content"] == "Tenant 2 Document Content"
# User 1 tries to access Tenant 2's documents
response_cross = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
message="What is in Tenant 2's documents?",
user_performing_action=test_user1,
)
# Assert that the search tool was used
assert response_cross.tool_name == "run_search"
# Assert that the tool_result is empty or does not contain Tenant 2's documents
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result}
# Ensure none of Tenant 2's document IDs are in the response
assert not response_doc_ids.intersection(tenant2_doc_ids)
# Optionally, assert that tool_result is empty
# assert len(response_cross.tool_result) == 0
# User 2 tries to access Tenant 1's documents
response_cross2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
message="What is in Tenant 1's documents?",
user_performing_action=test_user2,
)
# Assert that the search tool was used
assert response_cross2.tool_name == "run_search"
# Assert that the tool_result is empty or does not contain Tenant 1's documents
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result}
# Ensure none of Tenant 1's document IDs are in the response
assert not response_doc_ids.intersection(tenant1_doc_ids)
# Optionally, assert that tool_result is empty
# assert len(response_cross2.tool_result) == 0

View File

@@ -0,0 +1,41 @@
from danswer.configs.constants import DocumentSource
from danswer.db.enums import AccessType
from danswer.db.models import UserRole
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "test@test.com")
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.verify_role(test_user, UserRole.ADMIN)
test_credential = CredentialManager.create(
name="admin_test_credential",
source=DocumentSource.FILE,
curator_public=False,
user_performing_action=test_user,
)
test_connector = ConnectorManager.create(
name="admin_test_connector",
source=DocumentSource.FILE,
is_public=False,
user_performing_action=test_user,
)
test_cc_pair = CCPairManager.create(
connector_id=test_connector.id,
credential_id=test_credential.id,
name="admin_test_cc_pair",
access_type=AccessType.PRIVATE,
user_performing_action=test_user,
)
CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=test_user)

View File

@@ -9,7 +9,7 @@ export function SignInButton({
authType: AuthType;
}) {
let button;
if (authType === "google_oauth") {
if (authType === "google_oauth" || authType === "cloud") {
button = (
<div className="mx-auto flex">
<div className="my-auto mr-2">
@@ -42,7 +42,7 @@ export function SignInButton({
return (
<a
className="mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl}
>
{button}

View File

@@ -78,7 +78,7 @@ const Page = async ({
<HealthCheckBanner />
</div>
<div>
<div className="flex flex-col w-full justify-center">
{authUrl && authTypeMetadata && (
<>
<h2 className="text-center text-xl text-strong font-bold">
@@ -92,6 +92,26 @@ const Page = async ({
</>
)}
{authTypeMetadata?.authType === "cloud" && (
<div className="mt-4 w-full justify-center">
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-gray-300"></div>
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div>
</div>
<EmailPasswordForm shouldVerify={true} />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</div>
)}
{authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96">
<div className="flex">

View File

@@ -4,6 +4,7 @@ import {
getCurrentUserSS,
getAuthTypeMetadataSS,
AuthTypeMetadata,
getAuthUrlSS,
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
@@ -11,6 +12,8 @@ import { Card, Title, Text } from "@tremor/react";
import Link from "next/link";
import { Logo } from "@/components/Logo";
import { CLOUD_ENABLED } from "@/lib/constants";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
const Page = async () => {
// catch cases where the backend is completely unreachable here
@@ -26,9 +29,6 @@ const Page = async () => {
} catch (e) {
console.log(`Some fetch failed for the login page - ${e}`);
}
if (CLOUD_ENABLED) {
return redirect("/auth/login");
}
// simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") {
@@ -42,44 +42,56 @@ const Page = async () => {
}
return redirect("/auth/waiting-on-verification");
}
const cloud = authTypeMetadata?.authType === "cloud";
// only enable this page if basic login is enabled
if (authTypeMetadata?.authType !== "basic") {
if (authTypeMetadata?.authType !== "basic" && !cloud) {
return redirect("/");
}
let authUrl: string | null = null;
if (cloud && authTypeMetadata) {
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
}
return (
<main>
<div className="absolute top-10x w-full">
<HealthCheckBanner />
</div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div>
<Logo height={64} width={64} className="mx-auto w-fit" />
<AuthFlowContainer>
<HealthCheckBanner />
<Card className="mt-4 w-96">
<div className="flex">
<Title className="mb-2 mx-auto font-bold">
Sign Up for Danswer
</Title>
</div>
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
/>
<>
<div className="absolute top-10x w-full"></div>
<div className="flex w-full flex-col justify-center">
<h2 className="text-center text-xl text-strong font-bold">
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</h2>
<div className="flex">
<Text className="mt-4 mx-auto">
Already have an account?{" "}
<Link href="/auth/login" className="text-link font-medium">
Log In
</Link>
</Text>
{cloud && authUrl && (
<div className="w-full justify-center">
<SignInButton authorizeUrl={authUrl} authType="cloud" />
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-background-300"></div>
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-background-300"></div>
</div>
</div>
</Card>
)}
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
/>
<div className="flex">
<Text className="mt-4 mx-auto">
Already have an account?{" "}
<Link href="/auth/login" className="text-link font-medium">
Log In
</Link>
</Text>
</div>
</div>
</div>
</main>
</>
</AuthFlowContainer>
);
};

View File

@@ -7,7 +7,7 @@ export default function AuthFlowContainer({
}) {
return (
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<Logo width={70} height={70} />
{children}
</div>

View File

@@ -1,4 +1,10 @@
export type AuthType = "disabled" | "basic" | "google_oauth" | "oidc" | "saml";
export type AuthType =
| "disabled"
| "basic"
| "google_oauth"
| "oidc"
| "saml"
| "cloud";
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
export const HEADER_HEIGHT = "h-16";

View File

@@ -2,7 +2,7 @@ import { cookies } from "next/headers";
import { User } from "./types";
import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType } from "./constants";
import { AuthType, SERVER_SIDE_ONLY__CLOUD_ENABLED } from "./constants";
export interface AuthTypeMetadata {
authType: AuthType;
@@ -18,7 +18,15 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
const data: { auth_type: string; requires_verification: boolean } =
await res.json();
const authType = data.auth_type as AuthType;
let authType: AuthType;
// Override fasapi users auth so we can use both
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
authType = "cloud";
} else {
authType = data.auth_type as AuthType;
}
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
// Danswer in an un-authenticated state
@@ -87,6 +95,9 @@ export const getAuthUrlSS = async (
case "google_oauth": {
return await getGoogleOAuthUrlSS();
}
case "cloud": {
return await getGoogleOAuthUrlSS();
}
case "saml": {
return await getSAMLAuthUrlSS();
}