mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-22 01:46:47 +00:00
Compare commits
9 Commits
v3.3.0-clo
...
jamison/ti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b198eb433 | ||
|
|
524ddb6e68 | ||
|
|
3dbf62ba02 | ||
|
|
ec97ed0a73 | ||
|
|
7688718fcf | ||
|
|
08df1da283 | ||
|
|
ab845026fd | ||
|
|
2fb0860529 | ||
|
|
e4adbee2af |
@@ -1,17 +1,44 @@
|
||||
"""Captcha verification for user registration."""
|
||||
"""Captcha verification for user registration.
|
||||
|
||||
Two flows share this module:
|
||||
|
||||
1. Email/password signup. The token is posted with the signup body and
|
||||
verified inline by ``UserManager.create``.
|
||||
|
||||
2. Google OAuth signup. The OAuth callback request originates from Google
|
||||
as a browser redirect, so we cannot attach a header or body field to it
|
||||
at that moment. Instead the frontend verifies a reCAPTCHA token BEFORE
|
||||
redirecting to Google and we set a signed HttpOnly cookie. The cookie
|
||||
is sent automatically on the callback request, where middleware checks
|
||||
it. ``issue_captcha_cookie_value`` / ``validate_captcha_cookie_value``
|
||||
handle the HMAC signing + expiry.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import CAPTCHA_COOKIE_TTL_SECONDS
|
||||
from onyx.configs.app_configs import CAPTCHA_ENABLED
|
||||
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
|
||||
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"
|
||||
CAPTCHA_COOKIE_NAME = "onyx_captcha_verified"
|
||||
|
||||
# Google v3 tokens expire server-side at ~2 minutes, so 120s is the max useful
|
||||
# replay window — after that Google would reject the token anyway.
|
||||
_REPLAY_CACHE_TTL_SECONDS = 120
|
||||
_REPLAY_KEY_PREFIX = "captcha:replay:"
|
||||
|
||||
|
||||
class CaptchaVerificationError(Exception):
|
||||
@@ -34,6 +61,51 @@ def is_captcha_enabled() -> bool:
|
||||
return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)
|
||||
|
||||
|
||||
def _replay_cache_key(token: str) -> str:
|
||||
"""Avoid storing the raw token in Redis — hash it first."""
|
||||
digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
return f"{_REPLAY_KEY_PREFIX}{digest}"
|
||||
|
||||
|
||||
async def _reserve_token_or_raise(token: str) -> None:
|
||||
"""SETNX a token fingerprint. If another caller already claimed it within
|
||||
the TTL, reject as a replay. Fails open on Redis errors — losing replay
|
||||
protection is strictly better than hard-failing legitimate registrations
|
||||
if Redis blips."""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
claimed = await redis.set(
|
||||
_replay_cache_key(token),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=_REPLAY_CACHE_TTL_SECONDS,
|
||||
)
|
||||
if not claimed:
|
||||
logger.warning("Captcha replay detected: token already used")
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: token already used"
|
||||
)
|
||||
except CaptchaVerificationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Captcha replay cache error (failing open): {e}")
|
||||
|
||||
|
||||
async def _release_token(token: str) -> None:
|
||||
"""Unclaim a previously-reserved token so a retry with the same still-valid
|
||||
token is not blocked. Called when WE fail (network error talking to
|
||||
Google), not when Google rejects the token — Google rejections mean the
|
||||
token is permanently invalid and must stay claimed."""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(_replay_cache_key(token))
|
||||
except Exception as e:
|
||||
# Worst case: the user must wait up to 120s before the TTL expires
|
||||
# on its own and they can retry. Still strictly better than failing
|
||||
# open on the reservation side.
|
||||
logger.error(f"Captcha replay cache release error (ignored): {e}")
|
||||
|
||||
|
||||
async def verify_captcha_token(
|
||||
token: str,
|
||||
expected_action: str = "signup",
|
||||
@@ -54,6 +126,11 @@ async def verify_captcha_token(
|
||||
if not token:
|
||||
raise CaptchaVerificationError("Captcha token is required")
|
||||
|
||||
# Claim the token first so a concurrent replay of the same value cannot
|
||||
# slip through the Google round-trip window. Done BEFORE calling Google
|
||||
# because even a still-valid token should only redeem once.
|
||||
await _reserve_token_or_raise(token)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
@@ -76,31 +153,113 @@ async def verify_captcha_token(
|
||||
f"Captcha verification failed: {', '.join(error_codes)}"
|
||||
)
|
||||
|
||||
# For reCAPTCHA v3, also check the score
|
||||
if result.score is not None:
|
||||
if result.score < RECAPTCHA_SCORE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: suspicious activity detected"
|
||||
)
|
||||
# Require v3 score. Google's public test secret returns no score
|
||||
# — that path must not be active in prod since it skips the only
|
||||
# human-vs-bot signal. A missing score here means captcha is
|
||||
# misconfigured (test secret in prod, or a v2 response slipped in
|
||||
# via an action mismatch).
|
||||
if result.score is None:
|
||||
logger.warning(
|
||||
"Captcha verification failed: siteverify returned no score (likely test secret in prod)"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: missing score"
|
||||
)
|
||||
|
||||
# Optionally verify the action matches
|
||||
if result.action and result.action != expected_action:
|
||||
logger.warning(
|
||||
f"Captcha action mismatch: {result.action} != {expected_action}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: action mismatch"
|
||||
)
|
||||
if result.score < RECAPTCHA_SCORE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: suspicious activity detected"
|
||||
)
|
||||
|
||||
if result.action and result.action != expected_action:
|
||||
logger.warning(
|
||||
f"Captcha action mismatch: {result.action} != {expected_action}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: action mismatch"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Captcha verification passed: score={result.score}, action={result.action}"
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Captcha API request failed: {e}")
|
||||
# In case of API errors, we might want to allow registration
|
||||
# to prevent blocking legitimate users. This is a policy decision.
|
||||
except CaptchaVerificationError:
|
||||
# Definitively-bad token (Google rejected it, score too low, action
|
||||
# mismatch). Keep the reservation so the same token cannot be
|
||||
# retried elsewhere during the TTL window.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Anything else — network failure, JSON decode error, Pydantic
|
||||
# validation error on an unexpected siteverify response shape — is
|
||||
# OUR inability to verify the token, not proof the token is bad.
|
||||
# Release the reservation so the user can retry with the same
|
||||
# still-valid token instead of being locked out for ~120s.
|
||||
logger.error(f"Captcha verification failed unexpectedly: {e}")
|
||||
await _release_token(token)
|
||||
raise CaptchaVerificationError("Captcha verification service unavailable")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth pre-redirect cookie helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cookie_signing_key() -> bytes:
|
||||
"""Derive a dedicated HMAC key from USER_AUTH_SECRET.
|
||||
|
||||
Using a separate derivation keeps the captcha cookie signature from
|
||||
being interchangeable with any other token that reuses USER_AUTH_SECRET.
|
||||
"""
|
||||
return hashlib.sha256(
|
||||
f"onyx-captcha-cookie-v1::{USER_AUTH_SECRET}".encode("utf-8")
|
||||
).digest()
|
||||
|
||||
|
||||
def issue_captcha_cookie_value(now: int | None = None) -> str:
|
||||
"""Produce an opaque cookie value encoding 'verified until <expiry>'.
|
||||
|
||||
Format: ``<expiry_epoch>.<hex_hmac>``. The presence of a valid
|
||||
unexpired signature proves the browser solved a captcha challenge
|
||||
recently on this origin.
|
||||
"""
|
||||
issued_at = now if now is not None else int(time.time())
|
||||
expiry = issued_at + CAPTCHA_COOKIE_TTL_SECONDS
|
||||
sig = hmac.new(
|
||||
_cookie_signing_key(), str(expiry).encode("utf-8"), hashlib.sha256
|
||||
).hexdigest()
|
||||
return f"{expiry}.{sig}"
|
||||
|
||||
|
||||
def validate_captcha_cookie_value(value: str | None) -> bool:
|
||||
"""Return True if the cookie value has a valid unexpired signature.
|
||||
|
||||
The cookie is NOT a JWT — it's a minimal two-field format produced by
|
||||
``issue_captcha_cookie_value``:
|
||||
|
||||
<expiry_epoch_seconds>.<hex_hmac_sha256>
|
||||
|
||||
We split on the first ``.``, parse the expiry as an integer, recompute
|
||||
the HMAC over the expiry string using the key derived from
|
||||
USER_AUTH_SECRET, and compare with ``hmac.compare_digest`` to avoid
|
||||
timing leaks. No base64, no JSON, no claims — anything fancier would
|
||||
be overkill for a short-lived "verified recently" cookie.
|
||||
"""
|
||||
if not value:
|
||||
return False
|
||||
parts = value.split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
expiry_str, provided_sig = parts
|
||||
try:
|
||||
expiry = int(expiry_str)
|
||||
except ValueError:
|
||||
return False
|
||||
if expiry < int(time.time()):
|
||||
return False
|
||||
expected_sig = hmac.new(
|
||||
_cookie_signing_key(), str(expiry).encode("utf-8"), hashlib.sha256
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(expected_sig, provided_sig)
|
||||
|
||||
@@ -100,6 +100,7 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.multi_llm import LLMTimeoutError
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
@@ -1277,6 +1278,32 @@ def _run_models(
|
||||
else:
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
elif isinstance(item, LLMTimeoutError):
|
||||
model_llm = setup.llms[model_idx]
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable "
|
||||
"(current default: 120 seconds)."
|
||||
)
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(type(item), item, item.__traceback__)
|
||||
)
|
||||
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
|
||||
stack_trace = stack_trace.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="CONNECTION_ERROR",
|
||||
is_retryable=True,
|
||||
details={
|
||||
"model": model_llm.config.model_name,
|
||||
"provider": model_llm.config.model_provider,
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
elif isinstance(item, Exception):
|
||||
# Yield a tagged error for this model but keep the other models running.
|
||||
# Do NOT decrement models_remaining — _run_model's finally always posts
|
||||
|
||||
@@ -180,6 +180,13 @@ DISPOSABLE_EMAIL_DOMAINS_URL = os.environ.get(
|
||||
"https://disposable.github.io/disposable-email-domains/domains.json",
|
||||
)
|
||||
|
||||
# Captcha cookie TTL — how long a verified captcha token remains valid in
|
||||
# the browser cookie before the user has to solve another challenge. Sized
|
||||
# to comfortably cover one Google OAuth round-trip (typically <10s) while
|
||||
# keeping the replay window tight. 120s also matches Google's own v3 token
|
||||
# lifetime, so a paired-up cookie + token never outlive each other.
|
||||
CAPTCHA_COOKIE_TTL_SECONDS = int(os.environ.get("CAPTCHA_COOKIE_TTL_SECONDS", "120"))
|
||||
|
||||
# OAuth Login Flow
|
||||
# Used for both Google OAuth2 and OIDC flows
|
||||
OAUTH_CLIENT_ID = (
|
||||
@@ -1069,6 +1076,12 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||
|
||||
# When true, GET /users is restricted to callers with READ_USERS so non-admins
|
||||
# cannot enumerate the tenant directory. Off by default to preserve sharing UX.
|
||||
USER_DIRECTORY_ADMIN_ONLY = (
|
||||
os.environ.get("USER_DIRECTORY_ADMIN_ONLY", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Limit on number of users a free trial tenant can invite (cloud only)
|
||||
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))
|
||||
|
||||
|
||||
@@ -519,18 +519,45 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
logger.debug(f"Settings of index {self._index_name} updated successfully.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def get_settings(self) -> dict[str, Any]:
|
||||
def get_settings(
|
||||
self,
|
||||
include_defaults: bool = False,
|
||||
flat_settings: bool = False,
|
||||
pretty: bool = False,
|
||||
human: bool = False,
|
||||
) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Gets the settings of the index.
|
||||
|
||||
Args:
|
||||
include_defaults: Whether to include default settings which have not
|
||||
been explicitly set. Defaults to False.
|
||||
flat_settings: Whether to return settings in flat format vs nested
|
||||
dictionaries. Defaults to False.
|
||||
pretty: Whether to pretty-format the returned JSON response.
|
||||
Defaults to False.
|
||||
human: Whether to return statistics in human-readable format.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
The settings of the index.
|
||||
The settings of the index, and optionally the default settings. If
|
||||
include_defaults is False, the default settings will be None.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error getting the settings of the index.
|
||||
"""
|
||||
logger.debug(f"Getting settings of index {self._index_name}.")
|
||||
response = self._client.indices.get_settings(index=self._index_name)
|
||||
return response[self._index_name]["settings"]
|
||||
params = {
|
||||
"include_defaults": str(include_defaults).lower(),
|
||||
"flat_settings": str(flat_settings).lower(),
|
||||
"pretty": str(pretty).lower(),
|
||||
"human": str(human).lower(),
|
||||
}
|
||||
response = self._client.indices.get_settings(
|
||||
index=self._index_name, params=params
|
||||
)
|
||||
return response[self._index_name]["settings"], response[self._index_name].get(
|
||||
"defaults", None
|
||||
)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def open_index(self) -> None:
|
||||
|
||||
@@ -290,7 +290,11 @@ def litellm_exception_to_error_msg(
|
||||
error_code = "BUDGET_EXCEEDED"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, Timeout):
|
||||
error_msg = "Request timed out: The operation took too long to complete. Please try again."
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable (current default: 120 seconds)."
|
||||
)
|
||||
error_code = "CONNECTION_ERROR"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, APIError):
|
||||
@@ -743,7 +747,13 @@ def model_is_reasoning_model(model_name: str, model_provider: str) -> bool:
|
||||
model_name,
|
||||
)
|
||||
if model_obj and "supports_reasoning" in model_obj:
|
||||
return model_obj["supports_reasoning"]
|
||||
reasoning = model_obj["supports_reasoning"]
|
||||
if reasoning is None:
|
||||
logger.error(
|
||||
f"Cannot find reasoning for name={model_name} and provider={model_provider}"
|
||||
)
|
||||
reasoning = False
|
||||
return reasoning
|
||||
|
||||
# Fallback: try using litellm.supports_reasoning() for newer models
|
||||
try:
|
||||
|
||||
@@ -64,6 +64,8 @@ from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.hooks.registry import validate_registry
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth.captcha_api import CaptchaCookieMiddleware
|
||||
from onyx.server.auth.captcha_api import router as captcha_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
@@ -524,6 +526,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, mcp_admin_router)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, pat_router)
|
||||
include_router_with_global_prefix_prepended(application, captcha_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
include_auth_router_with_prefix(
|
||||
@@ -655,6 +658,10 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Gate the OAuth callback on a signed captcha cookie set by the frontend
|
||||
# before the Google redirect. No-op unless is_captcha_enabled() is true
|
||||
# (requires CAPTCHA_ENABLED=true and RECAPTCHA_SECRET_KEY set).
|
||||
application.add_middleware(CaptchaCookieMiddleware)
|
||||
if LOG_ENDPOINT_LATENCY:
|
||||
add_latency_logging_middleware(application, logger)
|
||||
|
||||
|
||||
0
backend/onyx/server/auth/__init__.py
Normal file
0
backend/onyx/server/auth/__init__.py
Normal file
118
backend/onyx/server/auth/captcha_api.py
Normal file
118
backend/onyx/server/auth/captcha_api.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""API + middleware for the reCAPTCHA pre-OAuth cookie flow.
|
||||
|
||||
The frontend solves a reCAPTCHA v3 challenge before clicking "Continue
|
||||
with Google", POSTs the token to ``/auth/captcha/oauth-verify``, and the
|
||||
backend verifies it with Google and sets a signed HttpOnly cookie. The
|
||||
cookie rides along on the subsequent Google OAuth callback redirect,
|
||||
where ``CaptchaCookieMiddleware`` checks it. Without this cookie flow
|
||||
the OAuth callback is un-gated because Google (not our frontend) issues
|
||||
the request and we cannot attach a header at the redirect hop.
|
||||
|
||||
Email/password signup has its own captcha enforcement inside
|
||||
``UserManager.create``, so this module only gates the OAuth callback.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.base import RequestResponseEndpoint
|
||||
|
||||
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
|
||||
from onyx.auth.captcha import CaptchaVerificationError
|
||||
from onyx.auth.captcha import is_captcha_enabled
|
||||
from onyx.auth.captcha import issue_captcha_cookie_value
|
||||
from onyx.auth.captcha import validate_captcha_cookie_value
|
||||
from onyx.auth.captcha import verify_captcha_token
|
||||
from onyx.configs.app_configs import CAPTCHA_COOKIE_TTL_SECONDS
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import onyx_error_to_json_response
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/auth/captcha", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
# Only the OAuth callback is gated here. /auth/register has its own
|
||||
# captcha enforcement inside UserManager.create via the body's
|
||||
# captcha_token field — the cookie layer is specifically for the OAuth
|
||||
# redirect that our frontend cannot attach a header to.
|
||||
GUARDED_OAUTH_CALLBACK_PATHS = frozenset({"/auth/oauth/callback"})
|
||||
|
||||
|
||||
class OAuthCaptchaVerifyRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
class OAuthCaptchaVerifyResponse(BaseModel):
|
||||
ok: bool
|
||||
|
||||
|
||||
@router.post("/oauth-verify")
|
||||
async def verify_oauth_captcha(
|
||||
body: OAuthCaptchaVerifyRequest,
|
||||
response: Response,
|
||||
) -> OAuthCaptchaVerifyResponse:
|
||||
"""Verify a reCAPTCHA token and set the OAuth-redirect cookie.
|
||||
|
||||
If captcha enforcement is off the endpoint is a no-op so the frontend
|
||||
doesn't block on dormant deployments.
|
||||
"""
|
||||
if not is_captcha_enabled():
|
||||
return OAuthCaptchaVerifyResponse(ok=True)
|
||||
|
||||
try:
|
||||
await verify_captcha_token(body.token, expected_action="oauth")
|
||||
except CaptchaVerificationError as exc:
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHORIZED, str(exc))
|
||||
|
||||
response.set_cookie(
|
||||
key=CAPTCHA_COOKIE_NAME,
|
||||
value=issue_captcha_cookie_value(),
|
||||
max_age=CAPTCHA_COOKIE_TTL_SECONDS,
|
||||
secure=True,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
path="/",
|
||||
)
|
||||
return OAuthCaptchaVerifyResponse(ok=True)
|
||||
|
||||
|
||||
class CaptchaCookieMiddleware(BaseHTTPMiddleware):
|
||||
"""Reject OAuth-callback requests that don't carry a valid captcha cookie.
|
||||
|
||||
No-op when ``is_captcha_enabled()`` is false so self-hosted and dev
|
||||
deployments pass through transparently.
|
||||
"""
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
# Skip OPTIONS so CORS preflight is never blocked.
|
||||
is_guarded_callback = (
|
||||
request.method != "OPTIONS"
|
||||
and request.url.path in GUARDED_OAUTH_CALLBACK_PATHS
|
||||
and is_captcha_enabled()
|
||||
)
|
||||
if is_guarded_callback:
|
||||
cookie_value = request.cookies.get(CAPTCHA_COOKIE_NAME)
|
||||
if not validate_captcha_cookie_value(cookie_value):
|
||||
return onyx_error_to_json_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Captcha challenge required. Refresh the page and try again.",
|
||||
)
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# One-time-use cookie: after the OAuth callback has been served, clear
|
||||
# it so the remaining TTL cannot be replayed (e.g. via browser
|
||||
# back-button) to re-enter the callback without a fresh challenge.
|
||||
if is_guarded_callback:
|
||||
response.delete_cookie(CAPTCHA_COOKIE_NAME, path="/")
|
||||
return response
|
||||
@@ -13,7 +13,6 @@ from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
PUBLIC_ENDPOINT_SPECS = [
|
||||
# built-in documentation functions
|
||||
("/openapi.json", {"GET", "HEAD"}),
|
||||
@@ -35,6 +34,10 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
("/auth/refresh", {"POST"}),
|
||||
("/auth/register", {"POST"}),
|
||||
("/auth/login", {"POST"}),
|
||||
# reCAPTCHA pre-OAuth challenge — user is not yet authenticated when
|
||||
# they solve it, and the endpoint's own handler enforces the only
|
||||
# thing that matters (valid Google siteverify response).
|
||||
("/auth/captcha/oauth-verify", {"POST"}),
|
||||
("/auth/logout", {"POST"}),
|
||||
("/auth/forgot-password", {"POST"}),
|
||||
("/auth/reset-password", {"POST"}),
|
||||
|
||||
236
backend/onyx/server/manage/invite_rate_limit.py
Normal file
236
backend/onyx/server/manage/invite_rate_limit.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Redis-backed rate limits for admin invite + remove-invited-user endpoints.
|
||||
|
||||
Defends against compromised-admin invite-spam and email-bomb abuse that
|
||||
nginx IP-keyed `limit_req` cannot stop (per-pod counters in multi-replica
|
||||
deployments, trivial IP rotation). Counters live in tenant-prefixed Redis
|
||||
so multi-pod api-server instances share state and per-admin / per-tenant
|
||||
quotas are enforced cluster-wide.
|
||||
|
||||
Check+increment is performed in a single Redis Lua script so two
|
||||
concurrent replicas cannot both pass the pre-check and both increment
|
||||
past the limit. When Redis is unavailable (e.g. Onyx Lite deployments
|
||||
where Redis is an opt-in `--profile redis` service), the rate limiter
|
||||
fails open with a logged warning so core invite flows continue to work.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from redis import Redis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
from redis.exceptions import RedisError
|
||||
from redis.exceptions import TimeoutError as RedisTimeoutError
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SECONDS_PER_MINUTE = 60
|
||||
_SECONDS_PER_DAY = 24 * 60 * 60
|
||||
|
||||
# Rate limits apply to trial tenants only (enforced at call site). Paid
|
||||
# tenants bypass entirely — their guardrails are seat limits and the
|
||||
# per-admin lifetime counter. Self-hosted / Lite deployments fail open
|
||||
# when Redis is unavailable. Values are sized against the trial lifetime
|
||||
# cap `NUM_FREE_TRIAL_USER_INVITES=10` and the invite→remove→invite
|
||||
# bypass attack: per-day caps stay tight so a compromised or scripted
|
||||
# trial admin cannot exceed the lifetime cap even across window rolls,
|
||||
# and per-minute caps block burst automation while leaving headroom for
|
||||
# a human typing emails by hand.
|
||||
_INVITE_ADMIN_PER_MIN = 3
|
||||
_INVITE_ADMIN_PER_DAY = 10
|
||||
_INVITE_TENANT_PER_DAY = 15
|
||||
_REMOVE_ADMIN_PER_MIN = 3
|
||||
_REMOVE_ADMIN_PER_DAY = 30
|
||||
|
||||
# Per-admin buckets are accidentally safe without an explicit tenant
|
||||
# prefix because admin UUIDs are globally unique. The tenant/day bucket
|
||||
# MUST embed the tenant_id directly: TenantRedis.__getattribute__ only
|
||||
# prefixes commands in an explicit allowlist and `eval` is not on it, so
|
||||
# keys passed to the Lua script reach Redis bare. A key of
|
||||
# "ratelimit:invite_put:tenant:day" would be shared across every trial
|
||||
# tenant, exhausting one global counter for the whole cluster.
|
||||
_INVITE_PUT_ADMIN_MIN_KEY = "ratelimit:invite_put:admin:{user_id}:min"
|
||||
_INVITE_PUT_ADMIN_DAY_KEY = "ratelimit:invite_put:admin:{user_id}:day"
|
||||
_INVITE_PUT_TENANT_DAY_KEY = "ratelimit:invite_put:tenant:{tenant_id}:day"
|
||||
_INVITE_REMOVE_ADMIN_MIN_KEY = "ratelimit:invite_remove:admin:{user_id}:min"
|
||||
_INVITE_REMOVE_ADMIN_DAY_KEY = "ratelimit:invite_remove:admin:{user_id}:day"
|
||||
|
||||
# Atomic multi-bucket check+increment.
|
||||
# ARGV[1] = N (bucket count). For each bucket i=1..N, ARGV[2+(i-1)*3..4+(i-1)*3]
|
||||
# carry increment, limit, ttl. KEYS[i] is the bucket's Redis key.
|
||||
#
|
||||
# Buckets with limit <= 0 or increment <= 0 are skipped (a disabled tier).
|
||||
# On reject, returns the 1-indexed bucket number that failed so the caller
|
||||
# can report which scope tripped; on success returns 0. TTL is set with NX
|
||||
# semantics so pre-existing keys without a TTL are still given one, but
|
||||
# fresh increments do not reset the window (fixed-window, not sliding).
|
||||
_CHECK_AND_INCREMENT_SCRIPT = """
|
||||
local n = tonumber(ARGV[1])
|
||||
for i = 1, n do
|
||||
local key = KEYS[i]
|
||||
local increment = tonumber(ARGV[2 + (i - 1) * 3])
|
||||
local limit = tonumber(ARGV[3 + (i - 1) * 3])
|
||||
if limit > 0 and increment > 0 then
|
||||
local current = tonumber(redis.call('get', key)) or 0
|
||||
if current + increment > limit then
|
||||
return i
|
||||
end
|
||||
end
|
||||
end
|
||||
for i = 1, n do
|
||||
local key = KEYS[i]
|
||||
local increment = tonumber(ARGV[2 + (i - 1) * 3])
|
||||
local limit = tonumber(ARGV[3 + (i - 1) * 3])
|
||||
local ttl = tonumber(ARGV[4 + (i - 1) * 3])
|
||||
if limit > 0 and increment > 0 then
|
||||
redis.call('incrby', key, increment)
|
||||
redis.call('expire', key, ttl, 'NX')
|
||||
end
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _Bucket:
|
||||
key: str
|
||||
limit: int
|
||||
ttl_seconds: int
|
||||
scope: str
|
||||
increment: int
|
||||
|
||||
|
||||
def _run_atomic(redis_client: Redis, buckets: list[_Bucket]) -> None:
|
||||
"""Run the check+increment Lua script. Raise OnyxError on rejection.
|
||||
|
||||
On Redis connection / timeout errors the rate limiter fails open: the
|
||||
request is allowed through and the failure is logged. This keeps the
|
||||
invite flow usable on Onyx Lite deployments (Redis is opt-in there)
|
||||
and during transient Redis outages in full deployments.
|
||||
"""
|
||||
if not buckets:
|
||||
return
|
||||
|
||||
keys = [b.key for b in buckets]
|
||||
argv: list[str] = [str(len(buckets))]
|
||||
for b in buckets:
|
||||
argv.extend([str(b.increment), str(b.limit), str(b.ttl_seconds)])
|
||||
|
||||
try:
|
||||
result = redis_client.eval( # type: ignore[no-untyped-call]
|
||||
_CHECK_AND_INCREMENT_SCRIPT,
|
||||
len(keys),
|
||||
*keys,
|
||||
*argv,
|
||||
)
|
||||
except (RedisConnectionError, RedisTimeoutError) as e:
|
||||
logger.warning(
|
||||
"Invite rate limiter skipped — Redis unavailable: %s. Rate limiting is disabled for this request.",
|
||||
e,
|
||||
)
|
||||
return
|
||||
except RedisError as e:
|
||||
logger.error(
|
||||
"Invite rate limiter Redis error, failing open: %s",
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
failed_index = int(result) if isinstance(result, (int, str, bytes)) else 0
|
||||
if failed_index <= 0:
|
||||
return
|
||||
|
||||
failed_bucket = buckets[failed_index - 1]
|
||||
logger.warning(
|
||||
"Invite rate limit hit: scope=%s key=%s adding=%d limit=%d",
|
||||
failed_bucket.scope,
|
||||
failed_bucket.key,
|
||||
failed_bucket.increment,
|
||||
failed_bucket.limit,
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.RATE_LIMITED,
|
||||
f"Invite rate limit exceeded ({failed_bucket.scope}). Try again later.",
|
||||
)
|
||||
|
||||
|
||||
def enforce_invite_rate_limit(
|
||||
redis_client: Redis,
|
||||
admin_user_id: UUID | str,
|
||||
num_invites: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""Check+record invite quotas for an admin user within their tenant.
|
||||
|
||||
Three tiers. Daily tiers track invite volume (so bulk invite of 20
|
||||
users counts as 20); the minute tier tracks request cadence (so a
|
||||
single legitimate bulk call does not trip the burst guard while an
|
||||
attacker spamming single-email requests does).
|
||||
|
||||
Raises OnyxError(RATE_LIMITED) without recording if any tier would be
|
||||
exceeded, so repeated rejected attempts do not consume budget.
|
||||
`num_invites` MUST be the count of new invites the request will send
|
||||
(not total emails in the body — deduplicate already-invited first).
|
||||
Zero-invite calls still tick the minute bucket so probe-floods of
|
||||
already-invited emails cannot bypass the burst guard.
|
||||
"""
|
||||
user_key = str(admin_user_id)
|
||||
daily_increment = max(0, num_invites)
|
||||
buckets = [
|
||||
_Bucket(
|
||||
key=_INVITE_PUT_TENANT_DAY_KEY.format(tenant_id=tenant_id),
|
||||
limit=_INVITE_TENANT_PER_DAY,
|
||||
ttl_seconds=_SECONDS_PER_DAY,
|
||||
scope="tenant/day",
|
||||
increment=daily_increment,
|
||||
),
|
||||
_Bucket(
|
||||
key=_INVITE_PUT_ADMIN_DAY_KEY.format(user_id=user_key),
|
||||
limit=_INVITE_ADMIN_PER_DAY,
|
||||
ttl_seconds=_SECONDS_PER_DAY,
|
||||
scope="admin/day",
|
||||
increment=daily_increment,
|
||||
),
|
||||
_Bucket(
|
||||
key=_INVITE_PUT_ADMIN_MIN_KEY.format(user_id=user_key),
|
||||
limit=_INVITE_ADMIN_PER_MIN,
|
||||
ttl_seconds=_SECONDS_PER_MINUTE,
|
||||
scope="admin/minute",
|
||||
increment=1,
|
||||
),
|
||||
]
|
||||
_run_atomic(redis_client, buckets)
|
||||
|
||||
|
||||
def enforce_remove_invited_rate_limit(
|
||||
redis_client: Redis,
|
||||
admin_user_id: UUID | str,
|
||||
) -> None:
|
||||
"""Check+record remove-invited-user quotas for an admin user.
|
||||
|
||||
Two tiers: per-admin per-day and per-admin per-minute. Removal itself
|
||||
does not send email, so there is no tenant-wide cap — the goal is to
|
||||
detect the PUT→PATCH abuse pattern by throttling PATCHes to roughly
|
||||
the cadence of legitimate administrative mistake correction.
|
||||
"""
|
||||
user_key = str(admin_user_id)
|
||||
buckets = [
|
||||
_Bucket(
|
||||
key=_INVITE_REMOVE_ADMIN_DAY_KEY.format(user_id=user_key),
|
||||
limit=_REMOVE_ADMIN_PER_DAY,
|
||||
ttl_seconds=_SECONDS_PER_DAY,
|
||||
scope="admin/day",
|
||||
increment=1,
|
||||
),
|
||||
_Bucket(
|
||||
key=_INVITE_REMOVE_ADMIN_MIN_KEY.format(user_id=user_key),
|
||||
limit=_REMOVE_ADMIN_PER_MIN,
|
||||
ttl_seconds=_SECONDS_PER_MINUTE,
|
||||
scope="admin/minute",
|
||||
increment=1,
|
||||
),
|
||||
]
|
||||
_run_atomic(redis_client, buckets)
|
||||
@@ -30,7 +30,6 @@ from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import USER_DIRECTORY_ADMIN_ONLY
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
@@ -81,10 +82,15 @@ from onyx.db.users import get_total_filtered_users_count
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import get_user_counts_by_role_and_status
|
||||
from onyx.db.users import validate_user_role_update
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.redis.redis_pool import get_raw_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.features.projects.models import UserFileSnapshot
|
||||
from onyx.server.manage.invite_rate_limit import enforce_invite_rate_limit
|
||||
from onyx.server.manage.invite_rate_limit import enforce_remove_invited_rate_limit
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import BulkInviteResponse
|
||||
@@ -469,6 +475,12 @@ def bulk_invite_users(
|
||||
status_code=403,
|
||||
detail="You have hit your invite limit. Please upgrade for unlimited invites.",
|
||||
)
|
||||
enforce_invite_rate_limit(
|
||||
redis_client=get_redis_client(tenant_id=tenant_id),
|
||||
admin_user_id=current_user.id,
|
||||
num_invites=len(emails_needing_seats),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Check seat availability for new users
|
||||
if emails_needing_seats:
|
||||
@@ -529,10 +541,17 @@ def bulk_invite_users(
|
||||
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)
|
||||
def remove_invited_user(
|
||||
user_email: UserByEmail,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
current_user: User = Depends(
|
||||
require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)
|
||||
),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> int:
|
||||
tenant_id = get_current_tenant_id()
|
||||
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
|
||||
enforce_remove_invited_rate_limit(
|
||||
redis_client=get_redis_client(tenant_id=tenant_id),
|
||||
admin_user_id=current_user.id,
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
@@ -672,15 +691,24 @@ def get_valid_domains(
|
||||
@router.get("/users", tags=PUBLIC_API_TAGS)
|
||||
def list_all_users_basic_info(
|
||||
include_api_keys: bool = False,
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserSnapshot]:
|
||||
if (
|
||||
USER_DIRECTORY_ADMIN_ONLY
|
||||
and Permission.READ_USERS not in get_effective_permissions(user)
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
users = get_all_users(db_session)
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.account_type != AccountType.BOT
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
MinimalUserSnapshot(id=u.id, email=u.email)
|
||||
for u in users
|
||||
if u.account_type != AccountType.BOT
|
||||
and (include_api_keys or not is_api_key_email_address(u.email))
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -986,11 +986,21 @@ async def search_chats(
|
||||
@router.post("/stop-chat-session/{chat_session_id}", tags=PUBLIC_API_TAGS)
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
try:
|
||||
get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(OnyxErrorCode.SESSION_NOT_FOUND, "Chat session not found")
|
||||
|
||||
set_fence(chat_session_id, get_cache_backend(), True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -5,7 +5,8 @@ Usage:
|
||||
source .venv/bin/activate
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py --help
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py list
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py delete <index_name>
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py delete
|
||||
<index_name>
|
||||
|
||||
Environment Variables:
|
||||
OPENSEARCH_HOST: OpenSearch host
|
||||
@@ -17,10 +18,11 @@ Dependencies:
|
||||
backend/shared_configs/configs.py
|
||||
backend/onyx/document_index/opensearch/client.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
@@ -61,6 +63,43 @@ def delete_index(client: OpenSearchIndexClient) -> None:
|
||||
print(f"Failed to delete index '{client._index_name}' for an unknown reason.")
|
||||
|
||||
|
||||
def get_settings(
|
||||
client: OpenSearchIndexClient,
|
||||
include_defaults: bool = False,
|
||||
flat_settings: bool = False,
|
||||
pretty: bool = False,
|
||||
human: bool = False,
|
||||
) -> None:
|
||||
settings, default_settings = client.get_settings(
|
||||
include_defaults=include_defaults,
|
||||
flat_settings=flat_settings,
|
||||
pretty=pretty,
|
||||
human=human,
|
||||
)
|
||||
print("Settings:")
|
||||
print(json.dumps(settings, indent=4))
|
||||
print("-" * 80)
|
||||
if default_settings:
|
||||
print("Default settings:")
|
||||
print(json.dumps(default_settings, indent=4))
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def set_settings(client: OpenSearchIndexClient, settings: dict[str, Any]) -> None:
|
||||
client.update_settings(settings)
|
||||
print(f"Updated settings for index '{client._index_name}'.")
|
||||
|
||||
|
||||
def open_index(client: OpenSearchIndexClient) -> None:
|
||||
client.open_index()
|
||||
print(f"Index '{client._index_name}' opened.")
|
||||
|
||||
|
||||
def close_index(client: OpenSearchIndexClient) -> None:
|
||||
client.close_index()
|
||||
print(f"Index '{client._index_name}' closed.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def add_standard_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
@@ -77,13 +116,19 @@ def main() -> None:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--username",
|
||||
help="OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, then prompt for input.",
|
||||
help=(
|
||||
"OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, then prompt for "
|
||||
"input."
|
||||
),
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_ADMIN_USERNAME", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--password",
|
||||
help="OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, then prompt for input.",
|
||||
help=(
|
||||
"OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, then prompt for "
|
||||
"input."
|
||||
),
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_ADMIN_PASSWORD", ""),
|
||||
)
|
||||
@@ -118,6 +163,47 @@ def main() -> None:
|
||||
delete_parser = subparsers.add_parser("delete", help="Delete an index.")
|
||||
delete_parser.add_argument("index", help="Index name.", type=str)
|
||||
|
||||
get_settings_parser = subparsers.add_parser(
|
||||
"get", help="Get settings for an index."
|
||||
)
|
||||
get_settings_parser.add_argument("index", help="Index name.", type=str)
|
||||
get_settings_parser.add_argument(
|
||||
"--include-defaults",
|
||||
help="Include default settings.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
get_settings_parser.add_argument(
|
||||
"--flat-settings",
|
||||
help="Return settings in flat format.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
get_settings_parser.add_argument(
|
||||
"--pretty",
|
||||
help="Pretty-format the returned JSON response.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
get_settings_parser.add_argument(
|
||||
"--human",
|
||||
help="Return statistics in human-readable format.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
|
||||
set_settings_parser = subparsers.add_parser(
|
||||
"set", help="Set settings for an index."
|
||||
)
|
||||
set_settings_parser.add_argument("index", help="Index name.", type=str)
|
||||
set_settings_parser.add_argument("settings", help="Settings to set.", type=str)
|
||||
|
||||
open_index_parser = subparsers.add_parser("open", help="Open an index.")
|
||||
open_index_parser.add_argument("index", help="Index name.", type=str)
|
||||
|
||||
close_index_parser = subparsers.add_parser("close", help="Close an index.")
|
||||
close_index_parser.add_argument("index", help="Index name.", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not (host := args.host or input("Enter the OpenSearch host: ")):
|
||||
@@ -134,18 +220,19 @@ def main() -> None:
|
||||
sys.exit(1)
|
||||
print("Using AWS-managed OpenSearch: ", args.use_aws_managed_opensearch)
|
||||
print(f"MULTI_TENANT: {MULTI_TENANT}")
|
||||
print()
|
||||
|
||||
with (
|
||||
OpenSearchIndexClient(
|
||||
index_name=args.index,
|
||||
OpenSearchClient(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=(username, password),
|
||||
use_ssl=not args.no_ssl,
|
||||
verify_certs=not args.no_verify_certs,
|
||||
)
|
||||
if args.command == "delete"
|
||||
else OpenSearchClient(
|
||||
if args.command == "list"
|
||||
else OpenSearchIndexClient(
|
||||
index_name=args.index,
|
||||
host=host,
|
||||
port=port,
|
||||
auth=(username, password),
|
||||
@@ -161,6 +248,23 @@ def main() -> None:
|
||||
list_indices(client)
|
||||
elif args.command == "delete":
|
||||
delete_index(client)
|
||||
elif args.command == "get":
|
||||
get_settings(
|
||||
client,
|
||||
include_defaults=args.include_defaults,
|
||||
flat_settings=args.flat_settings,
|
||||
pretty=args.pretty,
|
||||
human=args.human,
|
||||
)
|
||||
elif args.command == "set":
|
||||
set_settings(client, json.loads(args.settings))
|
||||
elif args.command == "open":
|
||||
open_index(client)
|
||||
elif args.command == "close":
|
||||
close_index(client)
|
||||
else:
|
||||
print(f"Unknown command: {args.command}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -456,7 +456,7 @@ class TestOpenSearchClient:
|
||||
# Assert that the current number of replicas is not the desired test
|
||||
# number we are updating to.
|
||||
test_num_replicas = 0
|
||||
current_settings = test_client.get_settings()
|
||||
current_settings, _ = test_client.get_settings()
|
||||
assert current_settings["index"]["number_of_replicas"] != f"{test_num_replicas}"
|
||||
|
||||
# Under test.
|
||||
@@ -467,7 +467,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
current_settings = test_client.get_settings()
|
||||
current_settings, _ = test_client.get_settings()
|
||||
assert current_settings["index"]["number_of_replicas"] == f"{test_num_replicas}"
|
||||
|
||||
def test_update_settings_on_nonexistent_index(
|
||||
@@ -488,7 +488,7 @@ class TestOpenSearchClient:
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
current_settings = test_client.get_settings()
|
||||
current_settings, _ = test_client.get_settings()
|
||||
|
||||
# Postcondition.
|
||||
assert "index" in current_settings
|
||||
|
||||
@@ -183,3 +183,30 @@ def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def _stop_chat_session(chat_session_id: str, user: DATestUser) -> requests.Response:
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/chat/stop-chat-session/{chat_session_id}",
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_stop_chat_session_rejects_non_owner(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Non-owner callers must not be able to stop another user's chat session."""
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can stop their own session.
|
||||
response = _stop_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# A different authenticated user must not be able to stop it.
|
||||
response = _stop_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Unknown session IDs are also rejected.
|
||||
response = _stop_chat_session(str(uuid4()), second_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
47
backend/tests/unit/onyx/auth/test_captcha_cookie.py
Normal file
47
backend/tests/unit/onyx/auth/test_captcha_cookie.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Unit tests for the OAuth reCAPTCHA cookie helpers in onyx.auth.captcha."""
|
||||
|
||||
from onyx.auth import captcha as captcha_module
|
||||
|
||||
|
||||
def test_issued_cookie_validates() -> None:
|
||||
"""A freshly issued cookie passes validation."""
|
||||
cookie = captcha_module.issue_captcha_cookie_value()
|
||||
assert captcha_module.validate_captcha_cookie_value(cookie) is True
|
||||
|
||||
|
||||
def test_validate_rejects_none() -> None:
|
||||
assert captcha_module.validate_captcha_cookie_value(None) is False
|
||||
|
||||
|
||||
def test_validate_rejects_empty_string() -> None:
|
||||
assert captcha_module.validate_captcha_cookie_value("") is False
|
||||
|
||||
|
||||
def test_validate_rejects_malformed_no_separator() -> None:
|
||||
assert captcha_module.validate_captcha_cookie_value("nodot") is False
|
||||
|
||||
|
||||
def test_validate_rejects_non_numeric_expiry() -> None:
|
||||
assert captcha_module.validate_captcha_cookie_value("notanumber.deadbeef") is False
|
||||
|
||||
|
||||
def test_validate_rejects_tampered_signature() -> None:
|
||||
"""Swapping the signature while keeping the expiry is rejected."""
|
||||
cookie = captcha_module.issue_captcha_cookie_value()
|
||||
expiry, _sig = cookie.split(".", 1)
|
||||
tampered = f"{expiry}.deadbeefdeadbeefdeadbeefdeadbeef"
|
||||
assert captcha_module.validate_captcha_cookie_value(tampered) is False
|
||||
|
||||
|
||||
def test_validate_rejects_expired_timestamp() -> None:
|
||||
"""An expiry in the past is rejected even with a valid signature."""
|
||||
cookie = captcha_module.issue_captcha_cookie_value(now=0)
|
||||
assert captcha_module.validate_captcha_cookie_value(cookie) is False
|
||||
|
||||
|
||||
def test_validate_rejects_modified_expiry() -> None:
|
||||
"""Bumping the expiry forward invalidates the signature."""
|
||||
cookie = captcha_module.issue_captcha_cookie_value()
|
||||
_expiry, sig = cookie.split(".", 1)
|
||||
bumped = f"99999999999.{sig}"
|
||||
assert captcha_module.validate_captcha_cookie_value(bumped) is False
|
||||
172
backend/tests/unit/onyx/auth/test_captcha_replay.py
Normal file
172
backend/tests/unit/onyx/auth/test_captcha_replay.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Unit tests for the reCAPTCHA token replay cache."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.auth import captcha as captcha_module
|
||||
from onyx.auth.captcha import _replay_cache_key
|
||||
from onyx.auth.captcha import _reserve_token_or_raise
|
||||
from onyx.auth.captcha import CaptchaVerificationError
|
||||
from onyx.auth.captcha import verify_captcha_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reserve_token_succeeds_on_first_use() -> None:
|
||||
"""First SETNX claims the token; no exception."""
|
||||
fake_redis = MagicMock()
|
||||
fake_redis.set = AsyncMock(return_value=True)
|
||||
with patch.object(
|
||||
captcha_module,
|
||||
"get_async_redis_connection",
|
||||
AsyncMock(return_value=fake_redis),
|
||||
):
|
||||
await _reserve_token_or_raise("some-token")
|
||||
fake_redis.set.assert_awaited_once()
|
||||
await_args = fake_redis.set.await_args
|
||||
assert await_args is not None
|
||||
assert await_args.kwargs["nx"] is True
|
||||
assert await_args.kwargs["ex"] == 120
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reserve_token_rejects_replay() -> None:
|
||||
"""Second use of the same token within TTL → CaptchaVerificationError."""
|
||||
fake_redis = MagicMock()
|
||||
fake_redis.set = AsyncMock(return_value=False)
|
||||
with patch.object(
|
||||
captcha_module,
|
||||
"get_async_redis_connection",
|
||||
AsyncMock(return_value=fake_redis),
|
||||
):
|
||||
with pytest.raises(CaptchaVerificationError, match="token already used"):
|
||||
await _reserve_token_or_raise("replayed-token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reserve_token_fails_open_on_redis_error() -> None:
|
||||
"""A Redis blip must NOT block legitimate registrations."""
|
||||
with patch.object(
|
||||
captcha_module,
|
||||
"get_async_redis_connection",
|
||||
AsyncMock(side_effect=RuntimeError("redis down")),
|
||||
):
|
||||
# No exception raised — replay protection is gracefully skipped.
|
||||
await _reserve_token_or_raise("any-token")
|
||||
|
||||
|
||||
def test_replay_cache_key_is_sha256_prefixed() -> None:
|
||||
"""The stored key never contains the raw token."""
|
||||
key = _replay_cache_key("raw-value")
|
||||
assert key.startswith("captcha:replay:")
|
||||
assert "raw-value" not in key
|
||||
# Length = prefix + 64 hex chars.
|
||||
assert len(key) == len("captcha:replay:") + 64
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reservation_released_when_google_unreachable() -> None:
|
||||
"""If Google's siteverify itself errors (our side, not the token's), the
|
||||
replay reservation must be released so the user can retry with the same
|
||||
still-valid token instead of getting 'already used' for 120s."""
|
||||
fake_redis = MagicMock()
|
||||
fake_redis.set = AsyncMock(return_value=True)
|
||||
fake_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.post = AsyncMock(side_effect=httpx.ConnectError("network down"))
|
||||
fake_client.__aenter__ = AsyncMock(return_value=fake_client)
|
||||
fake_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(
|
||||
captcha_module,
|
||||
"get_async_redis_connection",
|
||||
AsyncMock(return_value=fake_redis),
|
||||
),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=fake_client),
|
||||
):
|
||||
with pytest.raises(CaptchaVerificationError, match="service unavailable"):
|
||||
await verify_captcha_token("valid-token", expected_action="signup")
|
||||
|
||||
# The reservation was claimed and then released.
|
||||
fake_redis.set.assert_awaited_once()
|
||||
fake_redis.delete.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reservation_released_on_unexpected_response_shape() -> None:
|
||||
"""Non-HTTP errors during response parsing (malformed JSON, pydantic
|
||||
validation failure) also release the reservation — they mean WE couldn't
|
||||
verify the token, not that the token is definitively invalid."""
|
||||
fake_redis = MagicMock()
|
||||
fake_redis.set = AsyncMock(return_value=True)
|
||||
fake_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Simulate Google returning something that json() still succeeds on but
|
||||
# fails RecaptchaResponse validation (e.g. success=true but with a wrong
|
||||
# shape that Pydantic rejects when coerced).
|
||||
fake_httpx_response = MagicMock()
|
||||
fake_httpx_response.raise_for_status = MagicMock()
|
||||
fake_httpx_response.json = MagicMock(side_effect=ValueError("not valid JSON"))
|
||||
fake_client = MagicMock()
|
||||
fake_client.post = AsyncMock(return_value=fake_httpx_response)
|
||||
fake_client.__aenter__ = AsyncMock(return_value=fake_client)
|
||||
fake_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(
|
||||
captcha_module,
|
||||
"get_async_redis_connection",
|
||||
AsyncMock(return_value=fake_redis),
|
||||
),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=fake_client),
|
||||
):
|
||||
with pytest.raises(CaptchaVerificationError, match="service unavailable"):
|
||||
await verify_captcha_token("valid-token", expected_action="signup")
|
||||
|
||||
fake_redis.set.assert_awaited_once()
|
||||
fake_redis.delete.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reservation_kept_when_google_rejects_token() -> None:
|
||||
"""If Google itself says the token is invalid (success=false, or score
|
||||
too low), the reservation must NOT be released — that token is known-bad
|
||||
for its entire lifetime and shouldn't be retryable."""
|
||||
fake_redis = MagicMock()
|
||||
fake_redis.set = AsyncMock(return_value=True)
|
||||
fake_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
fake_httpx_response = MagicMock()
|
||||
fake_httpx_response.raise_for_status = MagicMock()
|
||||
fake_httpx_response.json = MagicMock(
|
||||
return_value={
|
||||
"success": False,
|
||||
"error-codes": ["invalid-input-response"],
|
||||
}
|
||||
)
|
||||
fake_client = MagicMock()
|
||||
fake_client.post = AsyncMock(return_value=fake_httpx_response)
|
||||
fake_client.__aenter__ = AsyncMock(return_value=fake_client)
|
||||
fake_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(
|
||||
captcha_module,
|
||||
"get_async_redis_connection",
|
||||
AsyncMock(return_value=fake_redis),
|
||||
),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=fake_client),
|
||||
):
|
||||
with pytest.raises(CaptchaVerificationError, match="invalid-input-response"):
|
||||
await verify_captcha_token("bad-token", expected_action="signup")
|
||||
|
||||
fake_redis.set.assert_awaited_once()
|
||||
fake_redis.delete.assert_not_awaited()
|
||||
78
backend/tests/unit/onyx/auth/test_captcha_require_score.py
Normal file
78
backend/tests/unit/onyx/auth/test_captcha_require_score.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Unit tests for the require-score check in verify_captcha_token."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth import captcha as captcha_module
|
||||
from onyx.auth.captcha import CaptchaVerificationError
|
||||
from onyx.auth.captcha import verify_captcha_token
|
||||
|
||||
|
||||
def _fake_httpx_client_returning(payload: dict) -> MagicMock:
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.json = MagicMock(return_value=payload)
|
||||
client = MagicMock()
|
||||
client.post = AsyncMock(return_value=resp)
|
||||
client.__aenter__ = AsyncMock(return_value=client)
|
||||
client.__aexit__ = AsyncMock(return_value=None)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_when_score_missing() -> None:
|
||||
"""Siteverify response with no score field is rejected outright —
|
||||
closes the accidental 'test secret in prod' bypass path."""
|
||||
client = _fake_httpx_client_returning(
|
||||
{"success": True, "hostname": "testkey.google.com"}
|
||||
)
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
|
||||
):
|
||||
with pytest.raises(CaptchaVerificationError, match="missing score"):
|
||||
await verify_captcha_token("test-token", expected_action="signup")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_when_score_present_and_above_threshold() -> None:
|
||||
"""Sanity check the happy path still works with the tighter score rule."""
|
||||
client = _fake_httpx_client_returning(
|
||||
{
|
||||
"success": True,
|
||||
"score": 0.9,
|
||||
"action": "signup",
|
||||
"hostname": "cloud.onyx.app",
|
||||
}
|
||||
)
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
|
||||
):
|
||||
# Should not raise.
|
||||
await verify_captcha_token("fresh-token", expected_action="signup")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_when_score_below_threshold() -> None:
|
||||
"""A score present but below threshold still rejects (existing behavior,
|
||||
guarding against regression from this PR's restructure)."""
|
||||
client = _fake_httpx_client_returning(
|
||||
{
|
||||
"success": True,
|
||||
"score": 0.1,
|
||||
"action": "signup",
|
||||
"hostname": "cloud.onyx.app",
|
||||
}
|
||||
)
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
|
||||
):
|
||||
with pytest.raises(
|
||||
CaptchaVerificationError, match="suspicious activity detected"
|
||||
):
|
||||
await verify_captcha_token("low-score-token", expected_action="signup")
|
||||
0
backend/tests/unit/onyx/server/auth/__init__.py
Normal file
0
backend/tests/unit/onyx/server/auth/__init__.py
Normal file
210
backend/tests/unit/onyx/server/auth/test_captcha_api.py
Normal file
210
backend/tests/unit/onyx/server/auth/test_captcha_api.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Unit tests for the reCAPTCHA OAuth verify endpoint + cookie middleware."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from onyx.auth.captcha import CaptchaVerificationError
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.server.auth import captcha_api as captcha_api_module
|
||||
from onyx.server.auth.captcha_api import CaptchaCookieMiddleware
|
||||
from onyx.server.auth.captcha_api import router as captcha_router
|
||||
|
||||
|
||||
def build_app_with_middleware() -> FastAPI:
|
||||
"""Minimal app with the middleware + router + fake OAuth callback route."""
|
||||
app = FastAPI()
|
||||
register_onyx_exception_handlers(app)
|
||||
app.add_middleware(CaptchaCookieMiddleware)
|
||||
app.include_router(captcha_router)
|
||||
|
||||
@app.get("/auth/oauth/callback")
|
||||
async def _oauth_callback() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/auth/register")
|
||||
async def _register() -> dict[str, str]:
|
||||
# /auth/register is NOT gated by this middleware (it has its own
|
||||
# captcha enforcement in UserManager.create). Used here to prove the
|
||||
# middleware only touches /auth/oauth/callback.
|
||||
return {"status": "created"}
|
||||
|
||||
@app.get("/me")
|
||||
async def _me() -> dict[str, str]:
|
||||
return {"status": "not-guarded"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ---------- /auth/captcha/oauth-verify endpoint ----------
|
||||
|
||||
|
||||
def test_verify_endpoint_returns_ok_when_captcha_disabled() -> None:
|
||||
"""Dormant mode: endpoint is a no-op, no cookie issued."""
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=False):
|
||||
res = client.post("/auth/captcha/oauth-verify", json={"token": "whatever"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"ok": True}
|
||||
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
|
||||
|
||||
assert CAPTCHA_COOKIE_NAME not in res.cookies
|
||||
|
||||
|
||||
def test_verify_endpoint_sets_cookie_on_success() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with (
|
||||
patch.object(captcha_api_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(
|
||||
captcha_api_module,
|
||||
"verify_captcha_token",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
):
|
||||
res = client.post("/auth/captcha/oauth-verify", json={"token": "valid-token"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"ok": True}
|
||||
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
|
||||
|
||||
assert CAPTCHA_COOKIE_NAME in res.cookies
|
||||
|
||||
|
||||
def test_verify_endpoint_raises_onyx_error_on_failure() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with (
|
||||
patch.object(captcha_api_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(
|
||||
captcha_api_module,
|
||||
"verify_captcha_token",
|
||||
AsyncMock(
|
||||
side_effect=CaptchaVerificationError(
|
||||
"Captcha verification failed: invalid-input-response"
|
||||
)
|
||||
),
|
||||
),
|
||||
):
|
||||
res = client.post("/auth/captcha/oauth-verify", json={"token": "bad-token"})
|
||||
assert res.status_code == 403
|
||||
body = res.json()
|
||||
assert body["error_code"] == "UNAUTHORIZED"
|
||||
assert "invalid-input-response" in body["detail"]
|
||||
|
||||
|
||||
def test_verify_endpoint_rejects_missing_token() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
res = client.post("/auth/captcha/oauth-verify", json={})
|
||||
# Pydantic validation failure from missing `token`.
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
# ---------- CaptchaCookieMiddleware ----------
|
||||
|
||||
|
||||
def test_middleware_passes_through_when_captcha_disabled() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=False):
|
||||
res = client.get("/auth/oauth/callback")
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_middleware_blocks_oauth_callback_without_cookie() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
res = client.get("/auth/oauth/callback")
|
||||
assert res.status_code == 403
|
||||
body = res.json()
|
||||
assert body["error_code"] == "UNAUTHORIZED"
|
||||
assert "Captcha challenge required" in body["detail"]
|
||||
|
||||
|
||||
def test_middleware_allows_oauth_callback_with_valid_cookie() -> None:
|
||||
"""A correctly-signed unexpired cookie lets the OAuth callback through."""
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
cookie_value = captcha_api_module.issue_captcha_cookie_value()
|
||||
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
|
||||
|
||||
res = client.get(
|
||||
"/auth/oauth/callback",
|
||||
cookies={CAPTCHA_COOKIE_NAME: cookie_value},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_middleware_clears_cookie_after_successful_callback() -> None:
|
||||
"""One-time-use: cookie is deleted after the callback has been served so
|
||||
a replayed callback URL cannot re-enter without a fresh challenge."""
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
|
||||
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
cookie_value = captcha_api_module.issue_captcha_cookie_value()
|
||||
res = client.get(
|
||||
"/auth/oauth/callback",
|
||||
cookies={CAPTCHA_COOKIE_NAME: cookie_value},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
set_cookie = res.headers.get("set-cookie", "")
|
||||
# Starlette's delete_cookie emits an expired Max-Age=0 Set-Cookie for the name.
|
||||
assert CAPTCHA_COOKIE_NAME in set_cookie
|
||||
assert (
|
||||
"Max-Age=0" in set_cookie or 'expires="Thu, 01 Jan 1970' in set_cookie.lower()
|
||||
)
|
||||
|
||||
|
||||
def test_middleware_rejects_tampered_cookie() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
|
||||
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
res = client.get(
|
||||
"/auth/oauth/callback",
|
||||
cookies={CAPTCHA_COOKIE_NAME: "9999999999.deadbeef"},
|
||||
)
|
||||
assert res.status_code == 403
|
||||
|
||||
|
||||
def test_middleware_ignores_register_path() -> None:
|
||||
"""/auth/register has its own captcha enforcement in UserManager.create —
|
||||
the cookie middleware should NOT gate it."""
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
res = client.post("/auth/register", json={})
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_middleware_ignores_unrelated_paths() -> None:
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
res = client.get("/me")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_middleware_skips_options_preflight() -> None:
|
||||
"""CORS preflight must pass through even without a cookie."""
|
||||
app = build_app_with_middleware()
|
||||
client = TestClient(app)
|
||||
with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
|
||||
res = client.options("/auth/oauth/callback")
|
||||
# Not 403: preflight passed the captcha gate.
|
||||
assert res.status_code != 403
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -7,9 +7,12 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.manage.models import EmailInviteStatus
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.manage.users import bulk_invite_users
|
||||
from onyx.server.manage.users import remove_invited_user
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
|
||||
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
|
||||
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
|
||||
@@ -21,12 +24,13 @@ def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: None) -> None:
|
||||
emails = [f"user{i}@example.com" for i in range(6)]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
bulk_invite_users(emails=emails)
|
||||
bulk_invite_users(emails=emails, current_user=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "invite limit" in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
|
||||
@patch("onyx.server.manage.users.DEV_MODE", True)
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
|
||||
@@ -45,7 +49,7 @@ def test_trial_tenant_can_invite_within_limit(*_mocks: None) -> None:
|
||||
"""Trial tenants can invite users when under the limit."""
|
||||
emails = ["user1@example.com", "user2@example.com", "user3@example.com"]
|
||||
|
||||
result = bulk_invite_users(emails=emails)
|
||||
result = bulk_invite_users(emails=emails, current_user=MagicMock())
|
||||
|
||||
assert result.invited_count == 3
|
||||
assert result.email_invite_status == EmailInviteStatus.DISABLED
|
||||
@@ -60,6 +64,7 @@ _COMMON_PATCHES = [
|
||||
patch("onyx.server.manage.users.get_all_users", return_value=[]),
|
||||
patch("onyx.server.manage.users.write_invited_users", return_value=1),
|
||||
patch("onyx.server.manage.users.enforce_seat_limit"),
|
||||
patch("onyx.server.manage.users.enforce_invite_rate_limit"),
|
||||
]
|
||||
|
||||
|
||||
@@ -73,7 +78,7 @@ def _with_common_patches(fn: object) -> object:
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
|
||||
def test_email_invite_status_disabled(*_mocks: None) -> None:
|
||||
"""When email invites are disabled, status is disabled."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.DISABLED
|
||||
|
||||
@@ -83,7 +88,7 @@ def test_email_invite_status_disabled(*_mocks: None) -> None:
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", False)
|
||||
def test_email_invite_status_not_configured(*_mocks: None) -> None:
|
||||
"""When email invites are enabled but no server is configured, status is not_configured."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.NOT_CONFIGURED
|
||||
|
||||
@@ -94,7 +99,7 @@ def test_email_invite_status_not_configured(*_mocks: None) -> None:
|
||||
@patch("onyx.server.manage.users.send_user_email_invite")
|
||||
def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
|
||||
"""When email invites are enabled and configured, status is sent."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
|
||||
|
||||
mock_send.assert_called_once()
|
||||
assert result.email_invite_status == EmailInviteStatus.SENT
|
||||
@@ -109,7 +114,123 @@ def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
|
||||
)
|
||||
def test_email_invite_status_send_failed(*_mocks: None) -> None:
|
||||
"""When email sending throws, status is send_failed and invite is still saved."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.SEND_FAILED
|
||||
assert result.invited_count == 1
|
||||
|
||||
|
||||
# --- trial-only rate limit gating tests ---
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
|
||||
@patch("onyx.server.manage.users.DEV_MODE", True)
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
|
||||
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=False)
|
||||
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
|
||||
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.get_all_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
|
||||
@patch("onyx.server.manage.users.enforce_seat_limit")
|
||||
@patch(
|
||||
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda *_args: None,
|
||||
)
|
||||
def test_paid_tenant_bypasses_invite_rate_limit(
|
||||
_ee_fetch: MagicMock,
|
||||
_seat_limit: MagicMock,
|
||||
_write_invited: MagicMock,
|
||||
_get_all_users: MagicMock,
|
||||
_get_invited_users: MagicMock,
|
||||
_get_tenant_id: MagicMock,
|
||||
_is_trial: MagicMock,
|
||||
mock_rate_limit: MagicMock,
|
||||
) -> None:
|
||||
"""Paid tenants must not hit the invite rate limiter at all."""
|
||||
emails = [f"user{i}@example.com" for i in range(3)]
|
||||
bulk_invite_users(emails=emails, current_user=MagicMock())
|
||||
mock_rate_limit.assert_not_called()
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
|
||||
@patch("onyx.server.manage.users.DEV_MODE", True)
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
|
||||
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
|
||||
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
|
||||
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.get_all_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
|
||||
@patch("onyx.server.manage.users.enforce_seat_limit")
|
||||
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 50)
|
||||
@patch(
|
||||
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda *_args: None,
|
||||
)
|
||||
def test_trial_tenant_hits_invite_rate_limit(
|
||||
_ee_fetch: MagicMock,
|
||||
_seat_limit: MagicMock,
|
||||
_write_invited: MagicMock,
|
||||
_get_all_users: MagicMock,
|
||||
_get_invited_users: MagicMock,
|
||||
_get_tenant_id: MagicMock,
|
||||
_is_trial: MagicMock,
|
||||
mock_rate_limit: MagicMock,
|
||||
) -> None:
|
||||
"""Trial tenants must flow through the invite rate limiter."""
|
||||
emails = [f"user{i}@example.com" for i in range(3)]
|
||||
bulk_invite_users(emails=emails, current_user=MagicMock())
|
||||
mock_rate_limit.assert_called_once()
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.enforce_remove_invited_rate_limit")
|
||||
@patch("onyx.server.manage.users.remove_user_from_invited_users", return_value=0)
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
|
||||
@patch("onyx.server.manage.users.DEV_MODE", True)
|
||||
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=False)
|
||||
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
|
||||
@patch(
|
||||
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda *_args: None,
|
||||
)
|
||||
def test_paid_tenant_bypasses_remove_invited_rate_limit(
|
||||
_ee_fetch: MagicMock,
|
||||
_get_tenant_id: MagicMock,
|
||||
_is_trial: MagicMock,
|
||||
_remove_from_invited: MagicMock,
|
||||
mock_rate_limit: MagicMock,
|
||||
) -> None:
|
||||
"""Paid tenants must not hit the remove-invited rate limiter at all."""
|
||||
remove_invited_user(
|
||||
user_email=UserByEmail(user_email="user@example.com"),
|
||||
current_user=MagicMock(),
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
mock_rate_limit.assert_not_called()
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.enforce_remove_invited_rate_limit")
|
||||
@patch("onyx.server.manage.users.remove_user_from_invited_users", return_value=0)
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
|
||||
@patch("onyx.server.manage.users.DEV_MODE", True)
|
||||
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
|
||||
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
|
||||
@patch(
|
||||
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda *_args: None,
|
||||
)
|
||||
def test_trial_tenant_hits_remove_invited_rate_limit(
|
||||
_ee_fetch: MagicMock,
|
||||
_get_tenant_id: MagicMock,
|
||||
_is_trial: MagicMock,
|
||||
_remove_from_invited: MagicMock,
|
||||
mock_rate_limit: MagicMock,
|
||||
) -> None:
|
||||
"""Trial tenants must flow through the remove-invited rate limiter."""
|
||||
remove_invited_user(
|
||||
user_email=UserByEmail(user_email="user@example.com"),
|
||||
current_user=MagicMock(),
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
mock_rate_limit.assert_called_once()
|
||||
|
||||
374
backend/tests/unit/onyx/server/manage/test_invite_rate_limit.py
Normal file
374
backend/tests/unit/onyx/server/manage/test_invite_rate_limit.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Unit tests for the Redis-backed invite + remove-invited rate limits."""
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from redis import Redis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.invite_rate_limit import enforce_invite_rate_limit
|
||||
from onyx.server.manage.invite_rate_limit import enforce_remove_invited_rate_limit
|
||||
|
||||
|
||||
class _StubRedis:
|
||||
"""In-memory stand-in that mirrors the Lua script's semantics.
|
||||
|
||||
The production rate limiter drives all state via a single EVAL call.
|
||||
The stub reimplements that logic in Python so unit tests can assert
|
||||
behavior without a live Redis — matching semantics including the
|
||||
NX-style TTL that leaves existing TTLs intact on re-increment.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, int] = {}
|
||||
self.ttls: dict[str, int] = {}
|
||||
self.eval_fail: Exception | None = None
|
||||
|
||||
def eval(self, _script: str, num_keys: int, *args: Any) -> int:
|
||||
if self.eval_fail is not None:
|
||||
raise self.eval_fail
|
||||
keys = list(args[:num_keys])
|
||||
argv = list(args[num_keys:])
|
||||
n = int(argv[0])
|
||||
for i in range(n):
|
||||
key = keys[i]
|
||||
increment = int(argv[1 + i * 3])
|
||||
limit = int(argv[2 + i * 3])
|
||||
if limit > 0 and increment > 0:
|
||||
current = self.store.get(key, 0)
|
||||
if current + increment > limit:
|
||||
return i + 1
|
||||
for i in range(n):
|
||||
key = keys[i]
|
||||
increment = int(argv[1 + i * 3])
|
||||
limit = int(argv[2 + i * 3])
|
||||
ttl = int(argv[3 + i * 3])
|
||||
if limit > 0 and increment > 0:
|
||||
self.store[key] = self.store.get(key, 0) + increment
|
||||
if key not in self.ttls:
|
||||
self.ttls[key] = ttl
|
||||
return 0
|
||||
|
||||
|
||||
def _stub() -> Redis:
|
||||
return cast(Redis, _StubRedis())
|
||||
|
||||
|
||||
def test_invite_allows_under_all_tiers() -> None:
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
500,
|
||||
),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=10, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
stub = cast(_StubRedis, redis_client)
|
||||
assert stub.store[f"ratelimit:invite_put:admin:{user_id}:day"] == 10
|
||||
assert stub.store["ratelimit:invite_put:tenant:tenant_a:day"] == 10
|
||||
assert stub.store[f"ratelimit:invite_put:admin:{user_id}:min"] == 1
|
||||
|
||||
|
||||
def test_invite_minute_bucket_blocks_request_flood() -> None:
|
||||
"""Attacker firing single-email invites rapidly must trip admin/minute."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
5000,
|
||||
),
|
||||
):
|
||||
for _ in range(5):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=1, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=1, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.RATE_LIMITED
|
||||
|
||||
|
||||
def test_invite_bulk_call_does_not_trip_minute_bucket() -> None:
|
||||
"""Legitimate one-shot bulk call for many users should not hit minute cap."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
500,
|
||||
),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=20, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
stub = cast(_StubRedis, redis_client)
|
||||
assert stub.store[f"ratelimit:invite_put:admin:{user_id}:min"] == 1
|
||||
|
||||
|
||||
def test_invite_admin_daily_cap_enforced() -> None:
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN",
|
||||
1000,
|
||||
),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
5000,
|
||||
),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=50, tenant_id="tenant_a"
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=1, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
|
||||
def test_invite_tenant_daily_cap_enforced_across_admins() -> None:
|
||||
"""Tenant cap should trip even when traffic comes from multiple admins."""
|
||||
redis_client = _stub()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN",
|
||||
1000,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY",
|
||||
1000,
|
||||
),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 10),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, uuid4(), num_invites=6, tenant_id="tenant_a"
|
||||
)
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, uuid4(), num_invites=4, tenant_id="tenant_a"
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, uuid4(), num_invites=1, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
|
||||
def test_invite_rejected_request_does_not_consume_budget() -> None:
|
||||
"""A request that violates a tier must not increment the surviving tiers."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 10),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=10, tenant_id="tenant_a"
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=5, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
stub = cast(_StubRedis, redis_client)
|
||||
assert stub.store[f"ratelimit:invite_put:admin:{user_id}:day"] == 10
|
||||
assert stub.store["ratelimit:invite_put:tenant:tenant_a:day"] == 10
|
||||
assert stub.store[f"ratelimit:invite_put:admin:{user_id}:min"] == 1
|
||||
|
||||
|
||||
def test_invite_zero_new_invites_still_ticks_minute_bucket() -> None:
|
||||
"""Probes against already-invited emails must still tick the burst guard."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 2),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
5000,
|
||||
),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=0, tenant_id="tenant_a"
|
||||
)
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=0, tenant_id="tenant_a"
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=0, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
stub = cast(_StubRedis, redis_client)
|
||||
assert stub.store.get(f"ratelimit:invite_put:admin:{user_id}:day", 0) == 0
|
||||
assert stub.store.get("ratelimit:invite_put:tenant:tenant_a:day", 0) == 0
|
||||
|
||||
|
||||
def test_invite_limit_zero_disables_tier() -> None:
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 0),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 0),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 0),
|
||||
):
|
||||
for _ in range(100):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=10, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
|
||||
def test_invite_tenant_bucket_is_isolated_across_tenants() -> None:
|
||||
"""Regression guard: tenants MUST NOT share the tenant/day counter.
|
||||
TenantRedis does not prefix keys passed to `eval`, so the tenant_id is
|
||||
baked into the key string itself."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 1000),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 1000),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 10),
|
||||
):
|
||||
# Tenant A exhausts its own cap.
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=10, tenant_id="tenant_a"
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=1, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
# Tenant B must still have its full budget.
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, uuid4(), num_invites=10, tenant_id="tenant_b"
|
||||
)
|
||||
|
||||
stub = cast(_StubRedis, redis_client)
|
||||
assert stub.store["ratelimit:invite_put:tenant:tenant_a:day"] == 10
|
||||
assert stub.store["ratelimit:invite_put:tenant:tenant_b:day"] == 10
|
||||
|
||||
|
||||
def test_invite_fails_open_when_redis_unavailable() -> None:
|
||||
"""Onyx Lite deployments ship without Redis; invite flow must still work."""
|
||||
stub = _StubRedis()
|
||||
stub.eval_fail = RedisConnectionError("Redis not reachable")
|
||||
redis_client = cast(Redis, stub)
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 1),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 1),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 1),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=1_000_000, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
|
||||
def test_remove_minute_bucket_blocks_pattern_attack() -> None:
|
||||
"""PUT→PATCH spam must trip the remove-invited minute bucket."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_MIN",
|
||||
3,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_DAY",
|
||||
100,
|
||||
),
|
||||
):
|
||||
for _ in range(3):
|
||||
enforce_remove_invited_rate_limit(redis_client, user_id)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_remove_invited_rate_limit(redis_client, user_id)
|
||||
|
||||
|
||||
def test_remove_daily_cap_enforced() -> None:
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_MIN",
|
||||
1000,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_DAY",
|
||||
5,
|
||||
),
|
||||
):
|
||||
for _ in range(5):
|
||||
enforce_remove_invited_rate_limit(redis_client, user_id)
|
||||
with pytest.raises(OnyxError):
|
||||
enforce_remove_invited_rate_limit(redis_client, user_id)
|
||||
|
||||
|
||||
def test_ttls_set_on_first_increment_and_not_reset() -> None:
|
||||
"""TTL must be set on the first increment and must not be reset on later ones."""
|
||||
redis_client = _stub()
|
||||
user_id = uuid4()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 100),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
5000,
|
||||
),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=3, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
stub = cast(_StubRedis, redis_client)
|
||||
assert stub.ttls[f"ratelimit:invite_put:admin:{user_id}:day"] == 24 * 60 * 60
|
||||
assert stub.ttls["ratelimit:invite_put:tenant:tenant_a:day"] == 24 * 60 * 60
|
||||
assert stub.ttls[f"ratelimit:invite_put:admin:{user_id}:min"] == 60
|
||||
|
||||
stub.ttls[f"ratelimit:invite_put:admin:{user_id}:min"] = 999
|
||||
with (
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 100),
|
||||
patch("onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500),
|
||||
patch(
|
||||
"onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY",
|
||||
5000,
|
||||
),
|
||||
):
|
||||
enforce_invite_rate_limit(
|
||||
redis_client, user_id, num_invites=3, tenant_id="tenant_a"
|
||||
)
|
||||
|
||||
assert stub.ttls[f"ratelimit:invite_put:admin:{user_id}:min"] == 999
|
||||
@@ -0,0 +1,79 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.users import list_all_users_basic_info
|
||||
|
||||
|
||||
def _fake_user(
|
||||
email: str, account_type: AccountType = AccountType.STANDARD
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = email
|
||||
user.account_type = account_type
|
||||
return user
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.USER_DIRECTORY_ADMIN_ONLY", True)
|
||||
def test_list_all_users_basic_info_blocks_non_admin_when_directory_restricted() -> None:
|
||||
"""With the flag on, a caller lacking READ_USERS cannot enumerate the directory."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = [Permission.BASIC_ACCESS.value]
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
list_all_users_basic_info(
|
||||
include_api_keys=False,
|
||||
user=user,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.USER_DIRECTORY_ADMIN_ONLY", True)
|
||||
@patch("onyx.server.manage.users.get_all_users")
|
||||
def test_list_all_users_basic_info_allows_admin_when_directory_restricted(
|
||||
mock_get_all_users: MagicMock,
|
||||
) -> None:
|
||||
"""With the flag on, an admin (FULL_ADMIN_PANEL_ACCESS) still gets the directory."""
|
||||
admin = MagicMock()
|
||||
admin.effective_permissions = [Permission.FULL_ADMIN_PANEL_ACCESS.value]
|
||||
mock_get_all_users.return_value = [_fake_user("a@example.com")]
|
||||
|
||||
result = list_all_users_basic_info(
|
||||
include_api_keys=False,
|
||||
user=admin,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert [u.email for u in result] == ["a@example.com"]
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.USER_DIRECTORY_ADMIN_ONLY", False)
|
||||
@patch("onyx.server.manage.users.get_all_users")
|
||||
def test_list_all_users_basic_info_allows_non_admin_when_flag_off(
|
||||
mock_get_all_users: MagicMock,
|
||||
) -> None:
|
||||
"""With the flag off (default), non-admin callers continue to get the directory."""
|
||||
basic = MagicMock()
|
||||
basic.effective_permissions = [Permission.BASIC_ACCESS.value]
|
||||
mock_get_all_users.return_value = [
|
||||
_fake_user("human@example.com"),
|
||||
_fake_user("bot@example.com", account_type=AccountType.BOT),
|
||||
]
|
||||
|
||||
result = list_all_users_basic_info(
|
||||
include_api_keys=False,
|
||||
user=basic,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
# BOT accounts are filtered out; human account is returned.
|
||||
assert [u.email for u in result] == ["human@example.com"]
|
||||
@@ -40,6 +40,7 @@ services:
|
||||
- SMTP_USER=${SMTP_USER:-}
|
||||
- SMTP_PASS=${SMTP_PASS:-}
|
||||
- ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-}
|
||||
- USER_DIRECTORY_ADMIN_ONLY=${USER_DIRECTORY_ADMIN_ONLY:-}
|
||||
- EMAIL_FROM=${EMAIL_FROM:-}
|
||||
- OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
|
||||
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
@@ -49,6 +49,11 @@ SESSION_EXPIRE_TIME_SECONDS=604800
|
||||
#VALID_EMAIL_DOMAINS=
|
||||
|
||||
|
||||
# Set to "true" to restrict GET /users to admins. Non-admins can't enumerate
|
||||
# accounts or share agents with individual users; group sharing still works.
|
||||
#USER_DIRECTORY_ADMIN_ONLY=
|
||||
|
||||
|
||||
# Default values here are what Postgres uses by default, feel free to change.
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=password
|
||||
|
||||
@@ -32,6 +32,9 @@ USER_AUTH_SECRET=""
|
||||
# API_KEY_HASH_ROUNDS=
|
||||
### You can add a comma separated list of domains like onyx.app, only those domains will be allowed to signup/log in
|
||||
# VALID_EMAIL_DOMAINS=
|
||||
### Set to "true" to restrict GET /users to admins. Non-admins can't enumerate
|
||||
### accounts or share agents with individual users; group sharing still works.
|
||||
# USER_DIRECTORY_ADMIN_ONLY=
|
||||
|
||||
## Chat Configuration
|
||||
# HARD_DELETE_CHATS=
|
||||
@@ -172,7 +175,7 @@ LOG_ONYX_MODEL_INTERACTIONS=False
|
||||
|
||||
## Gen AI Settings
|
||||
# GEN_AI_MAX_TOKENS=
|
||||
# LLM_SOCKET_READ_TIMEOUT=
|
||||
LLM_SOCKET_READ_TIMEOUT=120
|
||||
# MAX_CHUNKS_FED_TO_CHAT=
|
||||
# DISABLE_LITELLM_STREAMING=
|
||||
# LITELLM_EXTRA_HEADERS=
|
||||
|
||||
@@ -1262,7 +1262,7 @@ configMap:
|
||||
S3_FILE_STORE_BUCKET_NAME: ""
|
||||
# Gen AI Settings
|
||||
GEN_AI_MAX_TOKENS: ""
|
||||
LLM_SOCKET_READ_TIMEOUT: "60"
|
||||
LLM_SOCKET_READ_TIMEOUT: "120"
|
||||
MAX_CHUNKS_FED_TO_CHAT: ""
|
||||
# Query Options
|
||||
DOC_TIME_DECAY: ""
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
/**
|
||||
* SignInButton — renders the SSO / OAuth sign-in button on the login page.
|
||||
*
|
||||
* When reCAPTCHA is enabled for this deployment (NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
* set at build time), the Google/OIDC/SAML OAuth click is intercepted to
|
||||
* (1) fetch a reCAPTCHA v3 token for the "oauth" action, (2) POST it to
|
||||
* /api/auth/captcha/oauth-verify which sets a signed HttpOnly cookie on the
|
||||
* response, and (3) then navigate to the authorize URL. The cookie is sent
|
||||
* automatically on the subsequent Google redirect back to our callback,
|
||||
* where the backend middleware verifies it.
|
||||
*
|
||||
* IMPORTANT: This component is rendered as part of the /auth/login page, which
|
||||
* is used in healthcheck and monitoring flows that issue headless (non-browser)
|
||||
* requests (e.g. `curl`). During server-side rendering of those requests,
|
||||
@@ -18,10 +26,13 @@
|
||||
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { AuthType } from "@/lib/constants";
|
||||
import { FcGoogle } from "react-icons/fc";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import { useCaptcha } from "@/lib/hooks/useCaptcha";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
interface SignInButtonProps {
|
||||
authorizeUrl: string;
|
||||
@@ -32,6 +43,10 @@ export default function SignInButton({
|
||||
authorizeUrl,
|
||||
authType,
|
||||
}: SignInButtonProps) {
|
||||
const { getCaptchaToken, isCaptchaEnabled } = useCaptcha();
|
||||
const [isVerifying, setIsVerifying] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
let button: string | undefined;
|
||||
let icon: React.FunctionComponent<IconProps> | undefined;
|
||||
|
||||
@@ -48,18 +63,84 @@ export default function SignInButton({
|
||||
throw new Error(`Unhandled authType: ${authType}`);
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
prominence={
|
||||
authType === AuthType.GOOGLE_OAUTH || authType === AuthType.CLOUD
|
||||
? "secondary"
|
||||
: "primary"
|
||||
async function handleClick(e: React.MouseEvent) {
|
||||
e.preventDefault();
|
||||
if (isVerifying) return;
|
||||
setIsVerifying(true);
|
||||
setError(null);
|
||||
// Stays true on the success branch so the button remains disabled until
|
||||
// the browser actually begins unloading for the OAuth redirect — prevents
|
||||
// a double-click window between `window.location.href = ...` and unload.
|
||||
let navigating = false;
|
||||
try {
|
||||
const token = await getCaptchaToken("oauth");
|
||||
if (!token) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(
|
||||
"Captcha: grecaptcha.execute returned no token. The widget may not have loaded yet."
|
||||
);
|
||||
setError("grecaptcha.execute returned no token");
|
||||
return;
|
||||
}
|
||||
width="full"
|
||||
icon={icon}
|
||||
href={authorizeUrl}
|
||||
>
|
||||
{button}
|
||||
</Button>
|
||||
const res = await fetch("/api/auth/captcha/oauth-verify", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ token }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(
|
||||
`Captcha verify rejected: status=${res.status} detail=${
|
||||
body.detail ?? "(none)"
|
||||
}`
|
||||
);
|
||||
setError(
|
||||
"Captcha verification failed. Please refresh your browser and try again."
|
||||
);
|
||||
return;
|
||||
}
|
||||
navigating = true;
|
||||
window.location.href = authorizeUrl;
|
||||
} catch (exc) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.error("Captcha verify request failed", exc);
|
||||
setError(exc instanceof Error ? exc.message : String(exc));
|
||||
} finally {
|
||||
if (!navigating) setIsVerifying(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Only the Google OAuth callback is gated by CaptchaCookieMiddleware on the
|
||||
// backend. OIDC/SAML callbacks have no cookie requirement, so running the
|
||||
// reCAPTCHA interception for them is wasted friction — and worse, a failed
|
||||
// captcha would block the sign-in entirely.
|
||||
const intercepted =
|
||||
isCaptchaEnabled &&
|
||||
(authType === AuthType.GOOGLE_OAUTH || authType === AuthType.CLOUD);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
prominence={
|
||||
authType === AuthType.GOOGLE_OAUTH || authType === AuthType.CLOUD
|
||||
? "secondary"
|
||||
: "primary"
|
||||
}
|
||||
width="full"
|
||||
icon={icon}
|
||||
href={intercepted ? undefined : authorizeUrl}
|
||||
onClick={intercepted ? handleClick : undefined}
|
||||
disabled={isVerifying}
|
||||
>
|
||||
{button}
|
||||
</Button>
|
||||
{error && (
|
||||
<Text as="p" mainUiMuted className="text-status-error-05 mt-2">
|
||||
{error}
|
||||
</Text>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -27,8 +27,10 @@ import { useUser } from "@/providers/UserProvider";
|
||||
import { Formik, useFormikContext } from "formik";
|
||||
import { useAgent } from "@/hooks/useAgents";
|
||||
import { Button, MessageCard } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { useLabels } from "@/lib/hooks";
|
||||
import { PersonaLabel } from "@/app/admin/agents/interfaces";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
|
||||
const YOUR_ORGANIZATION_TAB = "Your Organization";
|
||||
const USERS_AND_GROUPS_TAB = "Users & Groups";
|
||||
@@ -56,8 +58,12 @@ interface ShareAgentFormContentProps {
|
||||
function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
const { values, setFieldValue, handleSubmit, dirty, isSubmitting } =
|
||||
useFormikContext<ShareAgentFormValues>();
|
||||
const { data: usersData } = useShareableUsers({ includeApiKeys: true });
|
||||
const { data: usersData, error: usersError } = useShareableUsers({
|
||||
includeApiKeys: true,
|
||||
});
|
||||
const { data: groupsData } = useShareableGroups();
|
||||
const userDirectoryRestricted =
|
||||
usersError instanceof FetchError && usersError.status === 403;
|
||||
const { user: currentUser, isAdmin, isCurator } = useUser();
|
||||
const { agent: fullAgent } = useAgent(agentId ?? null);
|
||||
const shareAgentModal = useModal();
|
||||
@@ -70,12 +76,14 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
|
||||
// Create options for InputComboBox from all accepted users and groups
|
||||
const comboBoxOptions = useMemo(() => {
|
||||
const userOptions = acceptedUsers
|
||||
.filter((user) => user.id !== currentUser?.id)
|
||||
.map((user) => ({
|
||||
value: `user-${user.id}`,
|
||||
label: user.email,
|
||||
}));
|
||||
const userOptions = userDirectoryRestricted
|
||||
? []
|
||||
: acceptedUsers
|
||||
.filter((user) => user.id !== currentUser?.id)
|
||||
.map((user) => ({
|
||||
value: `user-${user.id}`,
|
||||
label: user.email,
|
||||
}));
|
||||
|
||||
const groupOptions = groups.map((group) => ({
|
||||
value: `group-${group.id}`,
|
||||
@@ -83,7 +91,10 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
}));
|
||||
|
||||
return [...userOptions, ...groupOptions];
|
||||
}, [acceptedUsers, groups, currentUser?.id]);
|
||||
}, [acceptedUsers, groups, currentUser?.id, userDirectoryRestricted]);
|
||||
|
||||
const comboBoxDisabled =
|
||||
userDirectoryRestricted && comboBoxOptions.length === 0;
|
||||
|
||||
// Compute owner and displayed users
|
||||
const ownerId = fullAgent?.owner?.id;
|
||||
@@ -214,14 +225,31 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
|
||||
<Tabs.Content value={USERS_AND_GROUPS_TAB}>
|
||||
<Section gap={0.5} alignItems="start">
|
||||
<InputComboBox
|
||||
placeholder="Add users and groups"
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={handleComboBoxSelect}
|
||||
options={comboBoxOptions}
|
||||
strict
|
||||
/>
|
||||
<Disabled
|
||||
disabled={comboBoxDisabled}
|
||||
tooltip={
|
||||
comboBoxDisabled
|
||||
? "Your administrator has restricted the user directory. Contact an admin to share this agent with other users."
|
||||
: undefined
|
||||
}
|
||||
tooltipSide="bottom"
|
||||
>
|
||||
<div className="w-full">
|
||||
<InputComboBox
|
||||
placeholder={
|
||||
userDirectoryRestricted
|
||||
? "Add groups"
|
||||
: "Add users and groups"
|
||||
}
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={handleComboBoxSelect}
|
||||
options={comboBoxOptions}
|
||||
strict
|
||||
disabled={comboBoxDisabled}
|
||||
/>
|
||||
</div>
|
||||
</Disabled>
|
||||
{(displayedUsers.length > 0 || displayedGroups.length > 0) && (
|
||||
<Section gap={0} alignItems="stretch">
|
||||
{/* Shared Users */}
|
||||
|
||||
Reference in New Issue
Block a user