mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-27 04:35:50 +00:00
Compare commits
23 Commits
v0.24.0-cl
...
gating
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d66bdd061 | ||
|
|
8f67f1715c | ||
|
|
3b365509e2 | ||
|
|
022cbdfccf | ||
|
|
ebec6f6b10 | ||
|
|
1cad9c7b3d | ||
|
|
b4e975013c | ||
|
|
dd26f92206 | ||
|
|
4d00ec45ad | ||
|
|
1a81c67a67 | ||
|
|
04f965e656 | ||
|
|
277d37e0ee | ||
|
|
3cd260131b | ||
|
|
ad21ee0e9a | ||
|
|
c7dc0e9af0 | ||
|
|
75c5de802b | ||
|
|
c39f590d0d | ||
|
|
82a9fda846 | ||
|
|
842d4ab2a8 | ||
|
|
cddcec4ea4 | ||
|
|
09dd7b424c | ||
|
|
a2fd8d5e0a | ||
|
|
802dc00f78 |
@@ -233,35 +233,60 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
current_tenant_id.reset(token)
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
@@ -319,7 +344,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
self.database = tenant_user_db # type: ignore
|
||||
|
||||
oauth_account_dict = {
|
||||
"oauth_name": oauth_name,
|
||||
|
||||
@@ -4,7 +4,6 @@ import time
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
@@ -79,6 +78,7 @@ def on_task_prerun(
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
tenant_id: str | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
@@ -91,7 +91,7 @@ def on_task_postrun(
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
@@ -110,7 +110,17 @@ def on_task_postrun(
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
@@ -118,7 +128,7 @@ def on_task_postrun(
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
@@ -171,7 +181,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
# decide some initial startup settings based on the celery worker's hostname
|
||||
# (set at the command line)
|
||||
# (set at the command line)'
|
||||
|
||||
hostname = sender.hostname
|
||||
if hostname.startswith("light"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
@@ -182,166 +193,155 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
elif hostname.startswith("indexing"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# TODO: why is this necessary for the indexer to do?
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
for tenant_id in tenant_ids:
|
||||
# TODO: why is this necessary for the indexer to do?
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice(
|
||||
"Running a first inference to warm up embedding model"
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
else:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
r = get_redis_client()
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time.monotonic()
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
time_start = time.monotonic()
|
||||
logger.notice("Redis: Readiness check starting.")
|
||||
while True:
|
||||
# Log all the locks in Redis
|
||||
all_locks = r.keys("*")
|
||||
logger.notice(f"Current Redis locks: {all_locks}")
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return # Exit the function for secondary workers
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
time_start = time.monotonic()
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker
|
||||
r = get_redis_client()
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
sender.primary_worker_lock = lock
|
||||
sender.primary_worker_locks[tenant_id] = lock
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
# @worker_process_init.connect
|
||||
@@ -367,14 +367,15 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
for tenant_id, lock in sender.primary_worker_locks.items():
|
||||
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
@@ -449,17 +450,18 @@ def on_setup_logging(
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
|
||||
Use the task_logger in this class to avoid double logging.
|
||||
|
||||
This cannot be done inside a regular beat task because it must run on schedule and
|
||||
a queue of existing work would starve the task from running.
|
||||
"""
|
||||
|
||||
# it's unclear to me whether using the hub's timer or the bootstep timer is better
|
||||
# Requires the Hub component
|
||||
requires = {"celery.worker.components:Hub"}
|
||||
|
||||
def __init__(self, worker: Any, **kwargs: Any) -> None:
|
||||
super().__init__(worker, **kwargs)
|
||||
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
|
||||
self.task_tref = None
|
||||
|
||||
@@ -478,42 +480,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
def run_periodic_task(self, worker: Any) -> None:
|
||||
try:
|
||||
if not worker.primary_worker_lock:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
if not hasattr(worker, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
# Retrieve all tenant IDs
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
lock: redis.lock.Lock = worker.primary_worker_lock
|
||||
for tenant_id in tenant_ids:
|
||||
lock = worker.primary_worker_locks.get(tenant_id)
|
||||
if not lock:
|
||||
continue # Skip if no lock for this tenant
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be computer sleep or a clock change."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
task_logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info("Primary worker lock: Acquire succeeded.")
|
||||
if lock.owned():
|
||||
task_logger.debug(
|
||||
f"Reacquiring primary worker lock for tenant {tenant_id}."
|
||||
)
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.error("Primary worker lock: Acquire failed!")
|
||||
raise TimeoutError("Primary worker lock could not be acquired!")
|
||||
task_logger.warning(
|
||||
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
worker.primary_worker_lock = lock
|
||||
except Exception:
|
||||
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
|
||||
)
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
|
||||
)
|
||||
worker.primary_worker_locks[tenant_id] = lock
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.error(f"Error in periodic task: {e}")
|
||||
|
||||
def stop(self, worker: Any) -> None:
|
||||
# Cancel the scheduled task when the worker stops
|
||||
@@ -583,14 +601,14 @@ tasks_to_schedule = [
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
for id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"options": task["options"],
|
||||
"args": (tenant_id,), # Must pass tenant_id as an argument
|
||||
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
|
||||
@@ -31,7 +31,10 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
@@ -44,7 +47,7 @@ def _get_deletion_status(
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
@@ -54,9 +57,14 @@ def _get_deletion_status(
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
)
|
||||
if not deletion_task:
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@ from danswer.redis.redis_pool import get_redis_client
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
|
||||
r = get_redis_client()
|
||||
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
|
||||
@@ -51,10 +51,10 @@ logger = setup_logger()
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_indexing(tenant_id: str | None) -> int | None:
|
||||
def check_for_indexing(*, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
@@ -64,7 +64,10 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
|
||||
return None
|
||||
else:
|
||||
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
@@ -367,7 +370,7 @@ def connector_indexing_task(
|
||||
attempt = None
|
||||
n_final_progress = 0
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
|
||||
@@ -38,8 +38,8 @@ logger = setup_logger()
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_pruning(tenant_id: str | None) -> None:
|
||||
r = get_redis_client()
|
||||
def check_for_pruning(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -204,7 +204,7 @@ def connector_pruning_generator_task(
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
@@ -66,6 +67,8 @@ from danswer.utils.variable_functionality import (
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@@ -74,11 +77,11 @@ from danswer.utils.variable_functionality import noop_fallback
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_vespa_sync_task(tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -640,7 +643,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False
|
||||
"""
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: redis.lock.Lock = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
|
||||
@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
|
||||
@@ -136,6 +136,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -160,6 +161,9 @@ class AuthType(str, Enum):
|
||||
OIDC = "oidc"
|
||||
SAML = "saml"
|
||||
|
||||
# google auth and basic
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
|
||||
@@ -41,7 +41,7 @@ class ConfluenceRateLimitError(Exception):
|
||||
|
||||
# # for testing purposes, rate limiting is written to fall back to a simpler
|
||||
# # rate limiting approach when redis is not available
|
||||
# r = get_redis_client()
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# for attempt in range(max_retries):
|
||||
# try:
|
||||
|
||||
@@ -241,8 +241,7 @@ def create_credential(
|
||||
curator_public=credential_data.curator_public,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush() # This ensures the credential gets an ID
|
||||
|
||||
db_session.flush() # This ensures the credential gets an IDcredentials
|
||||
_relate_credential_to_user_groups__no_commit(
|
||||
db_session=db_session,
|
||||
credential_id=credential.id,
|
||||
|
||||
@@ -295,30 +295,32 @@ async def get_async_session_with_tenant(
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
else:
|
||||
current_tenant_id.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
# Establish a raw connection without starting a transaction
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Execute SET search_path outside of any transaction
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
# Optionally verify the search_path was set correctly
|
||||
cursor.execute("SHOW search_path")
|
||||
cursor.fetchone()
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Proceed to create a session using the connection
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
@@ -332,6 +334,18 @@ def get_session_with_tenant(
|
||||
cursor.close()
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
logger.debug(
|
||||
f"Set search_path to {tenant_id} for connection {connection_record}"
|
||||
)
|
||||
|
||||
|
||||
def get_session_generator_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.models import Notification
|
||||
from danswer.db.models import User
|
||||
@@ -54,7 +55,9 @@ def get_notification_by_id(
|
||||
notif = db_session.get(Notification, notification_id)
|
||||
if not notif:
|
||||
raise ValueError(f"No notification found with id {notification_id}")
|
||||
if notif.user_id != user_id:
|
||||
if notif.user_id != user_id and not (
|
||||
notif.user_id is None and user.role == UserRole.ADMIN
|
||||
):
|
||||
raise PermissionError(
|
||||
f"User {user_id} is not authorized to access notification {notification_id}"
|
||||
)
|
||||
|
||||
@@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self) -> None:
|
||||
self.redis_client = get_redis_client()
|
||||
tenant_id = current_tenant_id.get()
|
||||
self.redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> Iterator[Session]:
|
||||
|
||||
@@ -79,6 +79,8 @@ def _get_answer_stream_processor(
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
answer_style_configs: AnswerStyleConfig,
|
||||
) -> StreamProcessor:
|
||||
print("ANSWERR STYES")
|
||||
print(answer_style_configs.__dict__)
|
||||
if answer_style_configs.citation_config:
|
||||
return build_citation_processor(
|
||||
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
|
||||
|
||||
@@ -226,6 +226,7 @@ def process_model_tokens(
|
||||
hold_quote = ""
|
||||
|
||||
for token in tokens:
|
||||
print(f"Token: {token}")
|
||||
model_previous = model_output
|
||||
model_output += token
|
||||
|
||||
|
||||
@@ -269,7 +269,7 @@ def get_application() -> FastAPI:
|
||||
# Server logs this during auth setup verification step
|
||||
pass
|
||||
|
||||
elif AUTH_TYPE == AuthType.BASIC:
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
@@ -301,7 +301,7 @@ def get_application() -> FastAPI:
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import functools
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import redis
|
||||
@@ -14,6 +17,72 @@ from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class TenantRedis(redis.Redis):
|
||||
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tenant_id: str = tenant_id
|
||||
|
||||
def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
|
||||
prefix: str = f"{self.tenant_id}:"
|
||||
if isinstance(key, str):
|
||||
if key.startswith(prefix):
|
||||
return key
|
||||
else:
|
||||
return prefix + key
|
||||
elif isinstance(key, bytes):
|
||||
prefix_bytes = prefix.encode()
|
||||
if key.startswith(prefix_bytes):
|
||||
return key
|
||||
else:
|
||||
return prefix_bytes + key
|
||||
elif isinstance(key, memoryview):
|
||||
key_bytes = key.tobytes()
|
||||
prefix_bytes = prefix.encode()
|
||||
if key_bytes.startswith(prefix_bytes):
|
||||
return key
|
||||
else:
|
||||
return memoryview(prefix_bytes + key_bytes)
|
||||
else:
|
||||
raise TypeError(f"Unsupported key type: {type(key)}")
|
||||
|
||||
def _prefix_method(self, method: Callable) -> Callable:
|
||||
@functools.wraps(method)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
if "name" in kwargs:
|
||||
kwargs["name"] = self._prefixed(kwargs["name"])
|
||||
elif len(args) > 0:
|
||||
args = (self._prefixed(args[0]),) + args[1:]
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def __getattribute__(self, item: str) -> Any:
|
||||
original_attr = super().__getattribute__(item)
|
||||
methods_to_wrap = [
|
||||
"lock",
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
"hset",
|
||||
"hget",
|
||||
"getset",
|
||||
"scan_iter",
|
||||
"owned",
|
||||
"reacquire",
|
||||
"create_lock",
|
||||
"startswith",
|
||||
] # Add all methods that need prefixing
|
||||
if item in methods_to_wrap and callable(original_attr):
|
||||
return self._prefix_method(original_attr)
|
||||
return original_attr
|
||||
|
||||
|
||||
class RedisPool:
|
||||
@@ -32,8 +101,10 @@ class RedisPool:
|
||||
def _init_pool(self) -> None:
|
||||
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
|
||||
|
||||
def get_client(self) -> Redis:
|
||||
return redis.Redis(connection_pool=self._pool)
|
||||
def get_client(self, tenant_id: str | None) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = "public"
|
||||
return TenantRedis(tenant_id, connection_pool=self._pool)
|
||||
|
||||
@staticmethod
|
||||
def create_pool(
|
||||
@@ -84,8 +155,8 @@ class RedisPool:
|
||||
redis_pool = RedisPool()
|
||||
|
||||
|
||||
def get_redis_client() -> Redis:
|
||||
return redis_pool.get_client()
|
||||
def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_client(tenant_id)
|
||||
|
||||
|
||||
# # Usage example
|
||||
|
||||
@@ -24,6 +24,7 @@ from danswer.db.connector_credential_pair import (
|
||||
)
|
||||
from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.engine import current_tenant_id
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -90,6 +91,7 @@ def get_cc_pair_full_info(
|
||||
cc_pair_id: int,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> CCPairFullInfo:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
@@ -136,6 +138,7 @@ def get_cc_pair_full_info(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
num_docs_indexed=documents_indexed,
|
||||
is_editable_for_current_user=is_editable_for_current_user,
|
||||
@@ -231,6 +234,7 @@ def prune_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers pruning on a particular cc_pair immediately"""
|
||||
|
||||
@@ -246,7 +250,7 @@ def prune_cc_pair(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
if rcp.is_pruning(db_session, r):
|
||||
raise HTTPException(
|
||||
|
||||
@@ -482,10 +482,11 @@ def get_connector_indexing_status(
|
||||
get_editable: bool = Query(
|
||||
False, description="If true, return editable document sets"
|
||||
),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
indexing_statuses: list[ConnectorIndexingStatus] = []
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# NOTE: If the connector is deleting behind the scenes,
|
||||
# accessing cc_pairs can be inconsistent and members like
|
||||
@@ -606,6 +607,7 @@ def get_connector_indexing_status(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
is_deletable=check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair,
|
||||
@@ -683,15 +685,18 @@ def create_connector_with_mock_credential(
|
||||
connector_response = create_connector(
|
||||
db_session=db_session, connector_data=connector_data
|
||||
)
|
||||
|
||||
mock_credential = CredentialBase(
|
||||
credential_json={}, admin_public=True, source=connector_data.source
|
||||
)
|
||||
credential = create_credential(
|
||||
mock_credential, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
access_type = (
|
||||
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
||||
)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
@@ -775,7 +780,7 @@ def connector_run_once(
|
||||
"""Used to trigger indexing on a set of cc_pairs associated with a
|
||||
single connector."""
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
connector_id = run_info.connector_id
|
||||
specified_credential_ids = run_info.credential_ids
|
||||
|
||||
@@ -54,6 +54,7 @@ def fetch_settings(
|
||||
Postgres calls"""
|
||||
general_settings = load_settings()
|
||||
user_notifications = get_reindex_notification(user, db_session)
|
||||
product_gating_notification = get_product_gating_notification(db_session)
|
||||
|
||||
try:
|
||||
kv_store = get_kv_store()
|
||||
@@ -61,11 +62,27 @@ def fetch_settings(
|
||||
except KvKeyNotFoundError:
|
||||
needs_reindexing = False
|
||||
|
||||
return UserSettings(
|
||||
print("product_gating_notification", product_gating_notification)
|
||||
# TODO: Clean up
|
||||
print("response is ", [product_gating_notification])
|
||||
response = UserSettings(
|
||||
**general_settings.model_dump(),
|
||||
notifications=user_notifications,
|
||||
notifications=[product_gating_notification]
|
||||
if product_gating_notification
|
||||
else user_notifications,
|
||||
needs_reindexing=needs_reindexing,
|
||||
)
|
||||
print("act is ", response)
|
||||
return response
|
||||
|
||||
|
||||
def get_product_gating_notification(db_session: Session) -> Notification | None:
|
||||
notification = get_notifications(
|
||||
user=None,
|
||||
notif_type=NotificationType.TRIAL_ENDS_TWO_DAYS,
|
||||
db_session=db_session,
|
||||
)
|
||||
return Notification.from_model(notification[0]) if notification else None
|
||||
|
||||
|
||||
def get_reindex_notification(
|
||||
|
||||
@@ -42,7 +42,9 @@ global_version.set_ee()
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
|
||||
def sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
@@ -50,7 +52,7 @@ def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -
|
||||
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_group_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None
|
||||
cc_pair_id: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
@@ -59,7 +61,7 @@ def sync_external_group_permissions_task(
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, tenant_id: str | None
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
@@ -72,7 +74,7 @@ def perform_ttl_management_task(
|
||||
name="check_sync_external_doc_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
|
||||
def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external permissions"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
@@ -89,7 +91,7 @@ def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
|
||||
name="check_sync_external_group_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
|
||||
def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external group permissions"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
@@ -106,7 +108,7 @@ def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
|
||||
name="check_ttl_management_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(tenant_id: str | None) -> None:
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
@@ -130,7 +132,7 @@ def check_ttl_management_task(tenant_id: str | None) -> None:
|
||||
name="autogenerate_usage_report_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_new_usage_report(
|
||||
@@ -179,7 +181,7 @@ for tenant_id in tenant_ids:
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"args": (tenant_id,), # Must pass tenant_id as an argument
|
||||
"kwargs": {"tenant_id": tenant_id}, # Pass tenant_id as a keyword argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
|
||||
@@ -8,6 +8,7 @@ from danswer.auth.users import User
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.notification import create_notification
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.setup import setup_danswer
|
||||
@@ -87,12 +88,17 @@ def gate_product(
|
||||
1) User has ended free trial without adding payment method
|
||||
2) User's card has declined
|
||||
"""
|
||||
token = current_tenant_id.set(current_tenant_id.get())
|
||||
tenant_id = product_gating_request.tenant_id
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
settings.product_gating = product_gating_request.product_gating
|
||||
store_settings(settings)
|
||||
|
||||
if product_gating_request.notification:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_notification(None, product_gating_request.notification, db_session)
|
||||
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.server.settings.models import GatingType
|
||||
|
||||
|
||||
@@ -15,6 +16,7 @@ class CreateTenantRequest(BaseModel):
|
||||
class ProductGatingRequest(BaseModel):
|
||||
tenant_id: str
|
||||
product_gating: GatingType
|
||||
notification: NotificationType | None = None
|
||||
|
||||
|
||||
class BillingInformation(BaseModel):
|
||||
|
||||
@@ -313,7 +313,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -313,7 +313,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432"
|
||||
- "5433"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ export interface Settings {
|
||||
export enum NotificationType {
|
||||
PERSONA_SHARED = "persona_shared",
|
||||
REINDEX_NEEDED = "reindex_needed",
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending",
|
||||
}
|
||||
|
||||
export interface Notification {
|
||||
|
||||
@@ -9,7 +9,7 @@ export function SignInButton({
|
||||
authType: AuthType;
|
||||
}) {
|
||||
let button;
|
||||
if (authType === "google_oauth") {
|
||||
if (authType === "google_oauth" || authType === "cloud") {
|
||||
button = (
|
||||
<div className="mx-auto flex">
|
||||
<div className="my-auto mr-2">
|
||||
@@ -42,7 +42,7 @@ export function SignInButton({
|
||||
|
||||
return (
|
||||
<a
|
||||
className="mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
|
||||
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
|
||||
href={authorizeUrl}
|
||||
>
|
||||
{button}
|
||||
|
||||
@@ -78,7 +78,7 @@ const Page = async ({
|
||||
<HealthCheckBanner />
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="flex flex-col w-full justify-center">
|
||||
{authUrl && authTypeMetadata && (
|
||||
<>
|
||||
<h2 className="text-center text-xl text-strong font-bold">
|
||||
@@ -92,6 +92,26 @@ const Page = async ({
|
||||
</>
|
||||
)}
|
||||
|
||||
{authTypeMetadata?.authType === "cloud" && (
|
||||
<div className="mt-4 w-full justify-center">
|
||||
<div className="flex items-center w-full my-4">
|
||||
<div className="flex-grow border-t border-gray-300"></div>
|
||||
<span className="px-4 text-gray-500">or</span>
|
||||
<div className="flex-grow border-t border-gray-300"></div>
|
||||
</div>
|
||||
<EmailPasswordForm shouldVerify={true} />
|
||||
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
Don't have an account?{" "}
|
||||
<Link href="/auth/signup" className="text-link font-medium">
|
||||
Create an account
|
||||
</Link>
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{authTypeMetadata?.authType === "basic" && (
|
||||
<Card className="mt-4 w-96">
|
||||
<div className="flex">
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
@@ -13,7 +13,7 @@ export const POST = async (request: NextRequest) => {
|
||||
}
|
||||
|
||||
// Delete cookies only if cloud is enabled (jwt auth)
|
||||
if (CLOUD_ENABLED) {
|
||||
if (NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
|
||||
const cookieOptions = {
|
||||
path: "/",
|
||||
|
||||
@@ -4,13 +4,14 @@ import {
|
||||
getCurrentUserSS,
|
||||
getAuthTypeMetadataSS,
|
||||
AuthTypeMetadata,
|
||||
getAuthUrlSS,
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||
import { Card, Title, Text } from "@tremor/react";
|
||||
import { Text } from "@tremor/react";
|
||||
import Link from "next/link";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { SignInButton } from "../login/SignInButton";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
|
||||
const Page = async () => {
|
||||
// catch cases where the backend is completely unreachable here
|
||||
@@ -26,9 +27,6 @@ const Page = async () => {
|
||||
} catch (e) {
|
||||
console.log(`Some fetch failed for the login page - ${e}`);
|
||||
}
|
||||
if (CLOUD_ENABLED) {
|
||||
return redirect("/auth/login");
|
||||
}
|
||||
|
||||
// simply take the user to the home page if Auth is disabled
|
||||
if (authTypeMetadata?.authType === "disabled") {
|
||||
@@ -42,44 +40,56 @@ const Page = async () => {
|
||||
}
|
||||
return redirect("/auth/waiting-on-verification");
|
||||
}
|
||||
const cloud = authTypeMetadata?.authType === "cloud";
|
||||
|
||||
// only enable this page if basic login is enabled
|
||||
if (authTypeMetadata?.authType !== "basic") {
|
||||
if (authTypeMetadata?.authType !== "basic" && !cloud) {
|
||||
return redirect("/");
|
||||
}
|
||||
|
||||
let authUrl: string | null = null;
|
||||
if (cloud && authTypeMetadata) {
|
||||
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
|
||||
}
|
||||
|
||||
return (
|
||||
<main>
|
||||
<div className="absolute top-10x w-full">
|
||||
<HealthCheckBanner />
|
||||
</div>
|
||||
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
||||
<div>
|
||||
<Logo height={64} width={64} className="mx-auto w-fit" />
|
||||
<AuthFlowContainer>
|
||||
<HealthCheckBanner />
|
||||
|
||||
<Card className="mt-4 w-96">
|
||||
<div className="flex">
|
||||
<Title className="mb-2 mx-auto font-bold">
|
||||
Sign Up for Danswer
|
||||
</Title>
|
||||
</div>
|
||||
<EmailPasswordForm
|
||||
isSignup
|
||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||
/>
|
||||
<>
|
||||
<div className="absolute top-10x w-full"></div>
|
||||
<div className="flex w-full flex-col justify-center">
|
||||
<h2 className="text-center text-xl text-strong font-bold">
|
||||
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
|
||||
</h2>
|
||||
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
Already have an account?{" "}
|
||||
<Link href="/auth/login" className="text-link font-medium">
|
||||
Log In
|
||||
</Link>
|
||||
</Text>
|
||||
{cloud && authUrl && (
|
||||
<div className="w-full justify-center">
|
||||
<SignInButton authorizeUrl={authUrl} authType="cloud" />
|
||||
<div className="flex items-center w-full my-4">
|
||||
<div className="flex-grow border-t border-background-300"></div>
|
||||
<span className="px-4 text-gray-500">or</span>
|
||||
<div className="flex-grow border-t border-background-300"></div>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
<EmailPasswordForm
|
||||
isSignup
|
||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||
/>
|
||||
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
Already have an account?{" "}
|
||||
<Link href="/auth/login" className="text-link font-medium">
|
||||
Log In
|
||||
</Link>
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
</>
|
||||
</AuthFlowContainer>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -174,20 +174,6 @@ export default async function RootLayout({
|
||||
process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : ""
|
||||
}`}
|
||||
>
|
||||
{productGating === GatingType.PARTIAL && (
|
||||
<div className="fixed top-0 left-0 right-0 z-50 bg-warning-100 text-warning-900 p-2 text-center">
|
||||
<p className="text-sm font-medium">
|
||||
Your account is pending payment!{" "}
|
||||
<a
|
||||
href="/admin/cloud-settings"
|
||||
className="font-bold underline hover:text-warning-700 transition-colors"
|
||||
>
|
||||
Update your billing information
|
||||
</a>{" "}
|
||||
or access will be suspended soon.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
<UserProvider>
|
||||
<ProviderContextProvider>
|
||||
<SettingsProvider settings={combinedSettings}>
|
||||
|
||||
@@ -27,7 +27,6 @@ export async function Layout({ children }: { children: React.ReactNode }) {
|
||||
|
||||
const authTypeMetadata = results[0] as AuthTypeMetadata | null;
|
||||
const user = results[1] as User | null;
|
||||
console.log("authTypeMetadata", authTypeMetadata);
|
||||
const authDisabled = authTypeMetadata?.authType === "disabled";
|
||||
const requiresVerification = authTypeMetadata?.requiresVerification;
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ export default function AuthFlowContainer({
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
|
||||
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
|
||||
<div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
|
||||
<Logo width={70} height={70} />
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -15,6 +15,7 @@ export function AnnouncementBanner() {
|
||||
settings?.settings.notifications || []
|
||||
);
|
||||
|
||||
console.log("notifications", localNotifications);
|
||||
useEffect(() => {
|
||||
const filteredNotifications = (
|
||||
settings?.settings.notifications || []
|
||||
@@ -32,7 +33,7 @@ export function AnnouncementBanner() {
|
||||
const handleDismiss = async (notificationId: number) => {
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/settings/notifications/${notificationId}/dismiss`,
|
||||
`/api/notifications/${notificationId}/dismiss`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
@@ -61,12 +62,12 @@ export function AnnouncementBanner() {
|
||||
{localNotifications
|
||||
.filter((notification) => !notification.dismissed)
|
||||
.map((notification) => {
|
||||
if (notification.notif_type == "reindex") {
|
||||
return (
|
||||
<div
|
||||
key={notification.id}
|
||||
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
|
||||
>
|
||||
return (
|
||||
<div
|
||||
key={notification.id}
|
||||
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
|
||||
>
|
||||
{notification.notif_type == "reindex" ? (
|
||||
<p className="text-center">
|
||||
Your index is out of date - we strongly recommend updating
|
||||
your search settings.{" "}
|
||||
@@ -77,24 +78,29 @@ export function AnnouncementBanner() {
|
||||
Update here
|
||||
</Link>
|
||||
</p>
|
||||
<button
|
||||
onClick={() => handleDismiss(notification.id)}
|
||||
className="absolute top-0 right-0 mt-2 mr-2"
|
||||
aria-label="Dismiss"
|
||||
>
|
||||
<CustomTooltip
|
||||
showTick
|
||||
citation
|
||||
delay={100}
|
||||
content="Dismiss"
|
||||
) : notification.notif_type == "two_day_trial_ending" ? (
|
||||
<p className="text-center">
|
||||
Your trial is ending soon - submit your billing information to
|
||||
continue using Danswer.{" "}
|
||||
<Link
|
||||
href="/admin/cloud-settings"
|
||||
className="ml-2 underline cursor-pointer"
|
||||
>
|
||||
<XIcon className="h-5 w-5" />
|
||||
</CustomTooltip>
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
Update here
|
||||
</Link>
|
||||
</p>
|
||||
) : null}
|
||||
<button
|
||||
onClick={() => handleDismiss(notification.id)}
|
||||
className="absolute top-0 right-0 mt-2 mr-2"
|
||||
aria-label="Dismiss"
|
||||
>
|
||||
<CustomTooltip showTick citation delay={100} content="Dismiss">
|
||||
<XIcon className="h-5 w-5" />
|
||||
</CustomTooltip>
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
export type AuthType = "disabled" | "basic" | "google_oauth" | "oidc" | "saml";
|
||||
export type AuthType =
|
||||
| "disabled"
|
||||
| "basic"
|
||||
| "google_oauth"
|
||||
| "oidc"
|
||||
| "saml"
|
||||
| "cloud";
|
||||
|
||||
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
|
||||
export const HEADER_HEIGHT = "h-16";
|
||||
|
||||
@@ -2,7 +2,7 @@ import { cookies } from "next/headers";
|
||||
import { User } from "./types";
|
||||
import { buildUrl } from "./utilsSS";
|
||||
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
|
||||
import { AuthType } from "./constants";
|
||||
import { AuthType, SERVER_SIDE_ONLY__CLOUD_ENABLED } from "./constants";
|
||||
|
||||
export interface AuthTypeMetadata {
|
||||
authType: AuthType;
|
||||
@@ -18,7 +18,15 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
|
||||
|
||||
const data: { auth_type: string; requires_verification: boolean } =
|
||||
await res.json();
|
||||
const authType = data.auth_type as AuthType;
|
||||
|
||||
let authType: AuthType;
|
||||
|
||||
// Override fasapi users auth so we can use both
|
||||
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
|
||||
authType = "cloud";
|
||||
} else {
|
||||
authType = data.auth_type as AuthType;
|
||||
}
|
||||
|
||||
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
|
||||
// Danswer in an un-authenticated state
|
||||
@@ -87,6 +95,9 @@ export const getAuthUrlSS = async (
|
||||
case "google_oauth": {
|
||||
return await getGoogleOAuthUrlSS();
|
||||
}
|
||||
case "cloud": {
|
||||
return await getGoogleOAuthUrlSS();
|
||||
}
|
||||
case "saml": {
|
||||
return await getSAMLAuthUrlSS();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user