Compare commits

...

9 Commits

30 changed files with 2051 additions and 84 deletions

View File

@@ -1,17 +1,44 @@
"""Captcha verification for user registration."""
"""Captcha verification for user registration.
Two flows share this module:
1. Email/password signup. The token is posted with the signup body and
verified inline by ``UserManager.create``.
2. Google OAuth signup. The OAuth callback request originates from Google
as a browser redirect, so we cannot attach a header or body field to it
at that moment. Instead the frontend verifies a reCAPTCHA token BEFORE
redirecting to Google and we set a signed HttpOnly cookie. The cookie
is sent automatically on the callback request, where middleware checks
it. ``issue_captcha_cookie_value`` / ``validate_captcha_cookie_value``
handle the HMAC signing + expiry.
"""
import hashlib
import hmac
import time
import httpx
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import CAPTCHA_COOKIE_TTL_SECONDS
from onyx.configs.app_configs import CAPTCHA_ENABLED
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.utils.logger import setup_logger
logger = setup_logger()
RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"
CAPTCHA_COOKIE_NAME = "onyx_captcha_verified"
# Google v3 tokens expire server-side at ~2 minutes, so 120s is the max useful
# replay window — after that Google would reject the token anyway.
_REPLAY_CACHE_TTL_SECONDS = 120
_REPLAY_KEY_PREFIX = "captcha:replay:"
class CaptchaVerificationError(Exception):
@@ -34,6 +61,51 @@ def is_captcha_enabled() -> bool:
return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)
def _replay_cache_key(token: str) -> str:
    """Build the Redis key holding a token's replay fingerprint.

    The raw captcha token never touches Redis — only its SHA-256 digest
    is used, so a Redis compromise cannot leak redeemable tokens.
    """
    fingerprint = hashlib.sha256(token.encode("utf-8")).hexdigest()
    return _REPLAY_KEY_PREFIX + fingerprint
async def _reserve_token_or_raise(token: str) -> None:
    """SETNX a token fingerprint. If another caller already claimed it within
    the TTL, reject as a replay. Fails open on Redis errors — losing replay
    protection is strictly better than hard-failing legitimate registrations
    if Redis blips.

    Raises:
        CaptchaVerificationError: The token fingerprint was already claimed
            within the TTL window (replay attempt).
    """
    try:
        redis = await get_async_redis_connection()
        # SET with nx=True returns a falsy value when the key already
        # exists, i.e. this exact token was claimed by an earlier request.
        claimed = await redis.set(
            _replay_cache_key(token),
            "1",
            nx=True,
            ex=_REPLAY_CACHE_TTL_SECONDS,
        )
        if not claimed:
            logger.warning("Captcha replay detected: token already used")
            raise CaptchaVerificationError(
                "Captcha verification failed: token already used"
            )
    except CaptchaVerificationError:
        # Must re-raise BEFORE the broad fail-open handler below, or our own
        # replay rejection would be swallowed as a "Redis error".
        raise
    except Exception as e:
        # Fail open: a Redis outage must not block legitimate registrations.
        logger.error(f"Captcha replay cache error (failing open): {e}")
async def _release_token(token: str) -> None:
    """Drop the replay reservation for ``token`` so a retry can succeed.

    Only invoked when verification failed on OUR side (e.g. a network error
    reaching Google). A Google rejection means the token is permanently
    invalid, so its reservation is deliberately left in place for the TTL.
    """
    try:
        redis = await get_async_redis_connection()
        await redis.delete(_replay_cache_key(token))
    except Exception as e:
        # Best effort only: if this delete fails, the user waits out the
        # ~120s TTL before the same token works again — still strictly
        # better than failing open on the reservation side.
        logger.error(f"Captcha replay cache release error (ignored): {e}")
async def verify_captcha_token(
token: str,
expected_action: str = "signup",
@@ -54,6 +126,11 @@ async def verify_captcha_token(
if not token:
raise CaptchaVerificationError("Captcha token is required")
# Claim the token first so a concurrent replay of the same value cannot
# slip through the Google round-trip window. Done BEFORE calling Google
# because even a still-valid token should only redeem once.
await _reserve_token_or_raise(token)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
@@ -76,31 +153,113 @@ async def verify_captcha_token(
f"Captcha verification failed: {', '.join(error_codes)}"
)
# For reCAPTCHA v3, also check the score
if result.score is not None:
if result.score < RECAPTCHA_SCORE_THRESHOLD:
logger.warning(
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
)
raise CaptchaVerificationError(
"Captcha verification failed: suspicious activity detected"
)
# Require v3 score. Google's public test secret returns no score
# — that path must not be active in prod since it skips the only
# human-vs-bot signal. A missing score here means captcha is
# misconfigured (test secret in prod, or a v2 response slipped in
# via an action mismatch).
if result.score is None:
logger.warning(
"Captcha verification failed: siteverify returned no score (likely test secret in prod)"
)
raise CaptchaVerificationError(
"Captcha verification failed: missing score"
)
# Optionally verify the action matches
if result.action and result.action != expected_action:
logger.warning(
f"Captcha action mismatch: {result.action} != {expected_action}"
)
raise CaptchaVerificationError(
"Captcha verification failed: action mismatch"
)
if result.score < RECAPTCHA_SCORE_THRESHOLD:
logger.warning(
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
)
raise CaptchaVerificationError(
"Captcha verification failed: suspicious activity detected"
)
if result.action and result.action != expected_action:
logger.warning(
f"Captcha action mismatch: {result.action} != {expected_action}"
)
raise CaptchaVerificationError(
"Captcha verification failed: action mismatch"
)
logger.debug(
f"Captcha verification passed: score={result.score}, action={result.action}"
)
except httpx.HTTPError as e:
logger.error(f"Captcha API request failed: {e}")
# In case of API errors, we might want to allow registration
# to prevent blocking legitimate users. This is a policy decision.
except CaptchaVerificationError:
# Definitively-bad token (Google rejected it, score too low, action
# mismatch). Keep the reservation so the same token cannot be
# retried elsewhere during the TTL window.
raise
except Exception as e:
# Anything else — network failure, JSON decode error, Pydantic
# validation error on an unexpected siteverify response shape — is
# OUR inability to verify the token, not proof the token is bad.
# Release the reservation so the user can retry with the same
# still-valid token instead of being locked out for ~120s.
logger.error(f"Captcha verification failed unexpectedly: {e}")
await _release_token(token)
raise CaptchaVerificationError("Captcha verification service unavailable")
# ---------------------------------------------------------------------------
# OAuth pre-redirect cookie helpers
# ---------------------------------------------------------------------------
def _cookie_signing_key() -> bytes:
    """Derive a dedicated HMAC key from USER_AUTH_SECRET.

    The "onyx-captcha-cookie-v1" domain-separation prefix guarantees this
    signature can never be interchanged with any other token that happens
    to reuse USER_AUTH_SECRET.
    """
    material = f"onyx-captcha-cookie-v1::{USER_AUTH_SECRET}"
    return hashlib.sha256(material.encode("utf-8")).digest()
def issue_captcha_cookie_value(now: int | None = None) -> str:
    """Produce an opaque cookie value encoding 'verified until <expiry>'.

    Format: ``<expiry_epoch>.<hex_hmac>``. A valid, unexpired signature
    proves this browser solved a captcha challenge recently on this origin.

    Args:
        now: Epoch seconds to treat as issuance time; defaults to the
            current wall clock. Exposed for deterministic testing.
    """
    if now is None:
        now = int(time.time())
    expiry = str(now + CAPTCHA_COOKIE_TTL_SECONDS)
    signature = hmac.new(
        _cookie_signing_key(), expiry.encode("utf-8"), hashlib.sha256
    ).hexdigest()
    return f"{expiry}.{signature}"
def validate_captcha_cookie_value(value: str | None) -> bool:
"""Return True if the cookie value has a valid unexpired signature.
The cookie is NOT a JWT — it's a minimal two-field format produced by
``issue_captcha_cookie_value``:
<expiry_epoch_seconds>.<hex_hmac_sha256>
We split on the first ``.``, parse the expiry as an integer, recompute
the HMAC over the expiry string using the key derived from
USER_AUTH_SECRET, and compare with ``hmac.compare_digest`` to avoid
timing leaks. No base64, no JSON, no claims — anything fancier would
be overkill for a short-lived "verified recently" cookie.
"""
if not value:
return False
parts = value.split(".", 1)
if len(parts) != 2:
return False
expiry_str, provided_sig = parts
try:
expiry = int(expiry_str)
except ValueError:
return False
if expiry < int(time.time()):
return False
expected_sig = hmac.new(
_cookie_signing_key(), str(expiry).encode("utf-8"), hashlib.sha256
).hexdigest()
return hmac.compare_digest(expected_sig, provided_sig)

View File

@@ -100,6 +100,7 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.multi_llm import LLMTimeoutError
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
@@ -1277,6 +1278,32 @@ def _run_models(
else:
if item is _MODEL_DONE:
models_remaining -= 1
elif isinstance(item, LLMTimeoutError):
model_llm = setup.llms[model_idx]
error_msg = (
"The LLM took too long to respond. "
"If you're running a local model, try increasing the "
"LLM_SOCKET_READ_TIMEOUT environment variable "
"(current default: 120 seconds)."
)
stack_trace = "".join(
traceback.format_exception(type(item), item, item.__traceback__)
)
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
stack_trace = stack_trace.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
error_code="CONNECTION_ERROR",
is_retryable=True,
details={
"model": model_llm.config.model_name,
"provider": model_llm.config.model_provider,
"model_index": model_idx,
},
)
elif isinstance(item, Exception):
# Yield a tagged error for this model but keep the other models running.
# Do NOT decrement models_remaining — _run_model's finally always posts

View File

@@ -180,6 +180,13 @@ DISPOSABLE_EMAIL_DOMAINS_URL = os.environ.get(
"https://disposable.github.io/disposable-email-domains/domains.json",
)
# Captcha cookie TTL — how long a verified captcha token remains valid in
# the browser cookie before the user has to solve another challenge. Sized
# to comfortably cover one Google OAuth round-trip (typically <10s) while
# keeping the replay window tight. 120s also matches Google's own v3 token
# lifetime, so a paired-up cookie + token never outlive each other.
CAPTCHA_COOKIE_TTL_SECONDS = int(os.environ.get("CAPTCHA_COOKIE_TTL_SECONDS", "120"))
# OAuth Login Flow
# Used for both Google OAuth2 and OIDC flows
OAUTH_CLIENT_ID = (
@@ -1069,6 +1076,12 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# When true, GET /users is restricted to callers with READ_USERS so non-admins
# cannot enumerate the tenant directory. Off by default to preserve sharing UX.
USER_DIRECTORY_ADMIN_ONLY = (
os.environ.get("USER_DIRECTORY_ADMIN_ONLY", "").lower() == "true"
)
# Limit on number of users a free trial tenant can invite (cloud only)
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))

View File

@@ -519,18 +519,45 @@ class OpenSearchIndexClient(OpenSearchClient):
logger.debug(f"Settings of index {self._index_name} updated successfully.")
@log_function_time(print_only=True, debug_only=True)
def get_settings(self) -> dict[str, Any]:
def get_settings(
self,
include_defaults: bool = False,
flat_settings: bool = False,
pretty: bool = False,
human: bool = False,
) -> tuple[dict[str, Any], dict[str, Any] | None]:
"""Gets the settings of the index.
Args:
include_defaults: Whether to include default settings which have not
been explicitly set. Defaults to False.
flat_settings: Whether to return settings in flat format vs nested
dictionaries. Defaults to False.
pretty: Whether to pretty-format the returned JSON response.
Defaults to False.
human: Whether to return statistics in human-readable format.
Defaults to False.
Returns:
The settings of the index.
The settings of the index, and optionally the default settings. If
include_defaults is False, the default settings will be None.
Raises:
Exception: There was an error getting the settings of the index.
"""
logger.debug(f"Getting settings of index {self._index_name}.")
response = self._client.indices.get_settings(index=self._index_name)
return response[self._index_name]["settings"]
params = {
"include_defaults": str(include_defaults).lower(),
"flat_settings": str(flat_settings).lower(),
"pretty": str(pretty).lower(),
"human": str(human).lower(),
}
response = self._client.indices.get_settings(
index=self._index_name, params=params
)
return response[self._index_name]["settings"], response[self._index_name].get(
"defaults", None
)
@log_function_time(print_only=True, debug_only=True)
def open_index(self) -> None:

View File

@@ -290,7 +290,11 @@ def litellm_exception_to_error_msg(
error_code = "BUDGET_EXCEEDED"
is_retryable = False
elif isinstance(core_exception, Timeout):
error_msg = "Request timed out: The operation took too long to complete. Please try again."
error_msg = (
"The LLM took too long to respond. "
"If you're running a local model, try increasing the "
"LLM_SOCKET_READ_TIMEOUT environment variable (current default: 120 seconds)."
)
error_code = "CONNECTION_ERROR"
is_retryable = True
elif isinstance(core_exception, APIError):
@@ -743,7 +747,13 @@ def model_is_reasoning_model(model_name: str, model_provider: str) -> bool:
model_name,
)
if model_obj and "supports_reasoning" in model_obj:
return model_obj["supports_reasoning"]
reasoning = model_obj["supports_reasoning"]
if reasoning is None:
logger.error(
f"Cannot find reasoning for name={model_name} and provider={model_provider}"
)
reasoning = False
return reasoning
# Fallback: try using litellm.supports_reasoning() for newer models
try:

View File

@@ -64,6 +64,8 @@ from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.registry import validate_registry
from onyx.server.api_key.api import router as api_key_router
from onyx.server.auth.captcha_api import CaptchaCookieMiddleware
from onyx.server.auth.captcha_api import router as captcha_router
from onyx.server.auth_check import check_router_auth
from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
@@ -524,6 +526,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, mcp_admin_router)
include_router_with_global_prefix_prepended(application, pat_router)
include_router_with_global_prefix_prepended(application, captcha_router)
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_auth_router_with_prefix(
@@ -655,6 +658,10 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
allow_methods=["*"],
allow_headers=["*"],
)
# Gate the OAuth callback on a signed captcha cookie set by the frontend
# before the Google redirect. No-op unless is_captcha_enabled() is true
# (requires CAPTCHA_ENABLED=true and RECAPTCHA_SECRET_KEY set).
application.add_middleware(CaptchaCookieMiddleware)
if LOG_ENDPOINT_LATENCY:
add_latency_logging_middleware(application, logger)

View File

View File

@@ -0,0 +1,118 @@
"""API + middleware for the reCAPTCHA pre-OAuth cookie flow.
The frontend solves a reCAPTCHA v3 challenge before clicking "Continue
with Google", POSTs the token to ``/auth/captcha/oauth-verify``, and the
backend verifies it with Google and sets a signed HttpOnly cookie. The
cookie rides along on the subsequent Google OAuth callback redirect,
where ``CaptchaCookieMiddleware`` checks it. Without this cookie flow
the OAuth callback is un-gated because Google (not our frontend) issues
the request and we cannot attach a header at the redirect hop.
Email/password signup has its own captcha enforcement inside
``UserManager.create``, so this module only gates the OAuth callback.
"""
from fastapi import APIRouter
from fastapi import Request
from fastapi import Response
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
from onyx.auth.captcha import CaptchaVerificationError
from onyx.auth.captcha import is_captcha_enabled
from onyx.auth.captcha import issue_captcha_cookie_value
from onyx.auth.captcha import validate_captcha_cookie_value
from onyx.auth.captcha import verify_captcha_token
from onyx.configs.app_configs import CAPTCHA_COOKIE_TTL_SECONDS
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import onyx_error_to_json_response
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/captcha", tags=PUBLIC_API_TAGS)
# Only the OAuth callback is gated here. /auth/register has its own
# captcha enforcement inside UserManager.create via the body's
# captcha_token field — the cookie layer is specifically for the OAuth
# redirect that our frontend cannot attach a header to.
GUARDED_OAUTH_CALLBACK_PATHS = frozenset({"/auth/oauth/callback"})
class OAuthCaptchaVerifyRequest(BaseModel):
    """Body of POST /auth/captcha/oauth-verify."""

    # reCAPTCHA v3 token the frontend obtained before the Google redirect.
    token: str
class OAuthCaptchaVerifyResponse(BaseModel):
    """Response of POST /auth/captcha/oauth-verify."""

    # Always True on success; failures are reported via OnyxError instead.
    ok: bool
@router.post("/oauth-verify")
async def verify_oauth_captcha(
body: OAuthCaptchaVerifyRequest,
response: Response,
) -> OAuthCaptchaVerifyResponse:
"""Verify a reCAPTCHA token and set the OAuth-redirect cookie.
If captcha enforcement is off the endpoint is a no-op so the frontend
doesn't block on dormant deployments.
"""
if not is_captcha_enabled():
return OAuthCaptchaVerifyResponse(ok=True)
try:
await verify_captcha_token(body.token, expected_action="oauth")
except CaptchaVerificationError as exc:
raise OnyxError(OnyxErrorCode.UNAUTHORIZED, str(exc))
response.set_cookie(
key=CAPTCHA_COOKIE_NAME,
value=issue_captcha_cookie_value(),
max_age=CAPTCHA_COOKIE_TTL_SECONDS,
secure=True,
httponly=True,
samesite="lax",
path="/",
)
return OAuthCaptchaVerifyResponse(ok=True)
class CaptchaCookieMiddleware(BaseHTTPMiddleware):
    """Reject OAuth-callback requests lacking a valid captcha cookie.

    No-op when ``is_captcha_enabled()`` is false so self-hosted and dev
    deployments pass through transparently.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # OPTIONS is exempt so CORS preflights are never blocked.
        guarded = (
            is_captcha_enabled()
            and request.method != "OPTIONS"
            and request.url.path in GUARDED_OAUTH_CALLBACK_PATHS
        )

        if guarded:
            cookie = request.cookies.get(CAPTCHA_COOKIE_NAME)
            if not validate_captcha_cookie_value(cookie):
                return onyx_error_to_json_response(
                    OnyxError(
                        OnyxErrorCode.UNAUTHORIZED,
                        "Captcha challenge required. Refresh the page and try again.",
                    )
                )

        response = await call_next(request)

        if guarded:
            # One-time use: clear the cookie once the callback is served so
            # the remaining TTL cannot be replayed (e.g. browser back-button)
            # to re-enter the callback without a fresh challenge.
            response.delete_cookie(CAPTCHA_COOKIE_NAME, path="/")
        return response

View File

@@ -13,7 +13,6 @@ from onyx.auth.users import current_user_with_expired_token
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
PUBLIC_ENDPOINT_SPECS = [
# built-in documentation functions
("/openapi.json", {"GET", "HEAD"}),
@@ -35,6 +34,10 @@ PUBLIC_ENDPOINT_SPECS = [
("/auth/refresh", {"POST"}),
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
# reCAPTCHA pre-OAuth challenge — user is not yet authenticated when
# they solve it, and the endpoint's own handler enforces the only
# thing that matters (valid Google siteverify response).
("/auth/captcha/oauth-verify", {"POST"}),
("/auth/logout", {"POST"}),
("/auth/forgot-password", {"POST"}),
("/auth/reset-password", {"POST"}),

View File

@@ -0,0 +1,236 @@
"""Redis-backed rate limits for admin invite + remove-invited-user endpoints.
Defends against compromised-admin invite-spam and email-bomb abuse that
nginx IP-keyed `limit_req` cannot stop (per-pod counters in multi-replica
deployments, trivial IP rotation). Counters live in tenant-prefixed Redis
so multi-pod api-server instances share state and per-admin / per-tenant
quotas are enforced cluster-wide.
Check+increment is performed in a single Redis Lua script so two
concurrent replicas cannot both pass the pre-check and both increment
past the limit. When Redis is unavailable (e.g. Onyx Lite deployments
where Redis is an opt-in `--profile redis` service), the rate limiter
fails open with a logged warning so core invite flows continue to work.
"""
from dataclasses import dataclass
from uuid import UUID
from redis import Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError
from redis.exceptions import TimeoutError as RedisTimeoutError
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
_SECONDS_PER_MINUTE = 60
_SECONDS_PER_DAY = 24 * 60 * 60
# Rate limits apply to trial tenants only (enforced at call site). Paid
# tenants bypass entirely — their guardrails are seat limits and the
# per-admin lifetime counter. Self-hosted / Lite deployments fail open
# when Redis is unavailable. Values are sized against the trial lifetime
# cap `NUM_FREE_TRIAL_USER_INVITES=10` and the invite→remove→invite
# bypass attack: per-day caps stay tight so a compromised or scripted
# trial admin cannot exceed the lifetime cap even across window rolls,
# and per-minute caps block burst automation while leaving headroom for
# a human typing emails by hand.
_INVITE_ADMIN_PER_MIN = 3
_INVITE_ADMIN_PER_DAY = 10
_INVITE_TENANT_PER_DAY = 15
_REMOVE_ADMIN_PER_MIN = 3
_REMOVE_ADMIN_PER_DAY = 30
# Per-admin buckets are accidentally safe without an explicit tenant
# prefix because admin UUIDs are globally unique. The tenant/day bucket
# MUST embed the tenant_id directly: TenantRedis.__getattribute__ only
# prefixes commands in an explicit allowlist and `eval` is not on it, so
# keys passed to the Lua script reach Redis bare. A key of
# "ratelimit:invite_put:tenant:day" would be shared across every trial
# tenant, exhausting one global counter for the whole cluster.
_INVITE_PUT_ADMIN_MIN_KEY = "ratelimit:invite_put:admin:{user_id}:min"
_INVITE_PUT_ADMIN_DAY_KEY = "ratelimit:invite_put:admin:{user_id}:day"
_INVITE_PUT_TENANT_DAY_KEY = "ratelimit:invite_put:tenant:{tenant_id}:day"
_INVITE_REMOVE_ADMIN_MIN_KEY = "ratelimit:invite_remove:admin:{user_id}:min"
_INVITE_REMOVE_ADMIN_DAY_KEY = "ratelimit:invite_remove:admin:{user_id}:day"
# Atomic multi-bucket check+increment.
# ARGV[1] = N (bucket count). For each bucket i=1..N, ARGV[2+(i-1)*3..4+(i-1)*3]
# carry increment, limit, ttl. KEYS[i] is the bucket's Redis key.
#
# Buckets with limit <= 0 or increment <= 0 are skipped (a disabled tier).
# On reject, returns the 1-indexed bucket number that failed so the caller
# can report which scope tripped; on success returns 0. TTL is set with NX
# semantics so pre-existing keys without a TTL are still given one, but
# fresh increments do not reset the window (fixed-window, not sliding).
_CHECK_AND_INCREMENT_SCRIPT = """
local n = tonumber(ARGV[1])
for i = 1, n do
local key = KEYS[i]
local increment = tonumber(ARGV[2 + (i - 1) * 3])
local limit = tonumber(ARGV[3 + (i - 1) * 3])
if limit > 0 and increment > 0 then
local current = tonumber(redis.call('get', key)) or 0
if current + increment > limit then
return i
end
end
end
for i = 1, n do
local key = KEYS[i]
local increment = tonumber(ARGV[2 + (i - 1) * 3])
local limit = tonumber(ARGV[3 + (i - 1) * 3])
local ttl = tonumber(ARGV[4 + (i - 1) * 3])
if limit > 0 and increment > 0 then
redis.call('incrby', key, increment)
redis.call('expire', key, ttl, 'NX')
end
end
return 0
"""
@dataclass(frozen=True)
class _Bucket:
    """One rate-limit tier fed to the check+increment Lua script."""

    # Redis key the counter lives under (tenant-scoped where required).
    key: str
    # Maximum counter value; a value <= 0 disables the bucket in the script.
    limit: int
    # Fixed window length; applied with EXPIRE ... NX so later increments
    # never extend an already-running window.
    ttl_seconds: int
    # Human-readable tier label ("admin/day", "tenant/day", ...) used in
    # the rejection log line and the user-facing error message.
    scope: str
    # How much this request adds to the counter; <= 0 skips the bucket.
    increment: int
def _run_atomic(redis_client: Redis, buckets: list[_Bucket]) -> None:
    """Run the check+increment Lua script. Raise OnyxError on rejection.

    On Redis connection / timeout errors the rate limiter fails open: the
    request is allowed through and the failure is logged. This keeps the
    invite flow usable on Onyx Lite deployments (Redis is opt-in there)
    and during transient Redis outages in full deployments.

    Raises:
        OnyxError: RATE_LIMITED when any bucket would exceed its limit.
    """
    if not buckets:
        return
    keys = [b.key for b in buckets]
    # ARGV layout the script expects: [N, inc1, lim1, ttl1, inc2, lim2, ...]
    argv: list[str] = [str(len(buckets))]
    for b in buckets:
        argv.extend([str(b.increment), str(b.limit), str(b.ttl_seconds)])
    try:
        # redis-py eval signature: (script, numkeys, *keys_then_args).
        result = redis_client.eval(  # type: ignore[no-untyped-call]
            _CHECK_AND_INCREMENT_SCRIPT,
            len(keys),
            *keys,
            *argv,
        )
    except (RedisConnectionError, RedisTimeoutError) as e:
        # Fail open: Redis being unreachable must not block invites.
        logger.warning(
            "Invite rate limiter skipped — Redis unavailable: %s. Rate limiting is disabled for this request.",
            e,
        )
        return
    except RedisError as e:
        logger.error(
            "Invite rate limiter Redis error, failing open: %s",
            e,
        )
        return
    # Script returns 0 on success, else the 1-indexed bucket that tripped.
    failed_index = int(result) if isinstance(result, (int, str, bytes)) else 0
    if failed_index <= 0:
        return
    failed_bucket = buckets[failed_index - 1]
    logger.warning(
        "Invite rate limit hit: scope=%s key=%s adding=%d limit=%d",
        failed_bucket.scope,
        failed_bucket.key,
        failed_bucket.increment,
        failed_bucket.limit,
    )
    raise OnyxError(
        OnyxErrorCode.RATE_LIMITED,
        f"Invite rate limit exceeded ({failed_bucket.scope}). Try again later.",
    )
def enforce_invite_rate_limit(
    redis_client: Redis,
    admin_user_id: UUID | str,
    num_invites: int,
    tenant_id: str,
) -> None:
    """Check+record invite quotas for an admin user within their tenant.

    Three tiers. Daily tiers track invite volume (so bulk invite of 20
    users counts as 20); the minute tier tracks request cadence (so a
    single legitimate bulk call does not trip the burst guard while an
    attacker spamming single-email requests does).

    Raises OnyxError(RATE_LIMITED) without recording if any tier would be
    exceeded, so repeated rejected attempts do not consume budget.

    `num_invites` MUST be the count of new invites the request will send
    (not total emails in the body — deduplicate already-invited first).
    Zero-invite calls still tick the minute bucket so probe-floods of
    already-invited emails cannot bypass the burst guard.
    """
    admin_key = str(admin_user_id)
    # Clamp so a (buggy) negative count cannot decrement-style skew limits.
    volume = num_invites if num_invites > 0 else 0

    tenant_day = _Bucket(
        key=_INVITE_PUT_TENANT_DAY_KEY.format(tenant_id=tenant_id),
        limit=_INVITE_TENANT_PER_DAY,
        ttl_seconds=_SECONDS_PER_DAY,
        scope="tenant/day",
        increment=volume,
    )
    admin_day = _Bucket(
        key=_INVITE_PUT_ADMIN_DAY_KEY.format(user_id=admin_key),
        limit=_INVITE_ADMIN_PER_DAY,
        ttl_seconds=_SECONDS_PER_DAY,
        scope="admin/day",
        increment=volume,
    )
    # Cadence guard: always 1 per request, even when volume is zero.
    admin_minute = _Bucket(
        key=_INVITE_PUT_ADMIN_MIN_KEY.format(user_id=admin_key),
        limit=_INVITE_ADMIN_PER_MIN,
        ttl_seconds=_SECONDS_PER_MINUTE,
        scope="admin/minute",
        increment=1,
    )
    _run_atomic(redis_client, [tenant_day, admin_day, admin_minute])
def enforce_remove_invited_rate_limit(
    redis_client: Redis,
    admin_user_id: UUID | str,
) -> None:
    """Check+record remove-invited-user quotas for an admin user.

    Two tiers: per-admin per-day and per-admin per-minute. Removal itself
    does not send email, so there is no tenant-wide cap — the goal is to
    detect the PUT→PATCH abuse pattern by throttling PATCHes to roughly
    the cadence of legitimate administrative mistake correction.
    """
    admin_key = str(admin_user_id)
    tiers = [
        _Bucket(
            key=_INVITE_REMOVE_ADMIN_DAY_KEY.format(user_id=admin_key),
            limit=_REMOVE_ADMIN_PER_DAY,
            ttl_seconds=_SECONDS_PER_DAY,
            scope="admin/day",
            increment=1,
        ),
        _Bucket(
            key=_INVITE_REMOVE_ADMIN_MIN_KEY.format(user_id=admin_key),
            limit=_REMOVE_ADMIN_PER_MIN,
            ttl_seconds=_SECONDS_PER_MINUTE,
            scope="admin/minute",
            increment=1,
        ),
    ]
    _run_atomic(redis_client, tiers)

View File

@@ -30,7 +30,6 @@ from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
if TYPE_CHECKING:
pass

View File

@@ -44,6 +44,7 @@ from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import USER_DIRECTORY_ADMIN_ONLY
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.configs.constants import PUBLIC_API_TAGS
@@ -81,10 +82,15 @@ from onyx.db.users import get_total_filtered_users_count
from onyx.db.users import get_user_by_email
from onyx.db.users import get_user_counts_by_role_and_status
from onyx.db.users import validate_user_role_update
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.key_value_store.factory import get_kv_store
from onyx.redis.redis_pool import get_raw_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.manage.invite_rate_limit import enforce_invite_rate_limit
from onyx.server.manage.invite_rate_limit import enforce_remove_invited_rate_limit
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import BulkInviteResponse
@@ -469,6 +475,12 @@ def bulk_invite_users(
status_code=403,
detail="You have hit your invite limit. Please upgrade for unlimited invites.",
)
enforce_invite_rate_limit(
redis_client=get_redis_client(tenant_id=tenant_id),
admin_user_id=current_user.id,
num_invites=len(emails_needing_seats),
tenant_id=tenant_id,
)
# Check seat availability for new users
if emails_needing_seats:
@@ -529,10 +541,17 @@ def bulk_invite_users(
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)
def remove_invited_user(
user_email: UserByEmail,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
current_user: User = Depends(
require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)
),
db_session: Session = Depends(get_session),
) -> int:
tenant_id = get_current_tenant_id()
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
enforce_remove_invited_rate_limit(
redis_client=get_redis_client(tenant_id=tenant_id),
admin_user_id=current_user.id,
)
if MULTI_TENANT:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
@@ -672,15 +691,24 @@ def get_valid_domains(
@router.get("/users", tags=PUBLIC_API_TAGS)
def list_all_users_basic_info(
include_api_keys: bool = False,
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> list[MinimalUserSnapshot]:
if (
USER_DIRECTORY_ADMIN_ONLY
and Permission.READ_USERS not in get_effective_permissions(user)
):
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You do not have the required permissions for this action.",
)
users = get_all_users(db_session)
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.account_type != AccountType.BOT
and (include_api_keys or not is_api_key_email_address(user.email))
MinimalUserSnapshot(id=u.id, email=u.email)
for u in users
if u.account_type != AccountType.BOT
and (include_api_keys or not is_api_key_email_address(u.email))
]

View File

@@ -986,11 +986,21 @@ async def search_chats(
@router.post("/stop-chat-session/{chat_session_id}", tags=PUBLIC_API_TAGS)
def stop_chat_session(
chat_session_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""
Stop a chat session by setting a stop signal.
This endpoint is called by the frontend when the user clicks the stop button.
"""
try:
get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=user.id,
db_session=db_session,
)
except ValueError:
raise OnyxError(OnyxErrorCode.SESSION_NOT_FOUND, "Chat session not found")
set_fence(chat_session_id, get_cache_backend(), True)
return {"message": "Chat session stopped"}

View File

@@ -5,7 +5,8 @@ Usage:
source .venv/bin/activate
python backend/scripts/debugging/opensearch/opensearch_debug.py --help
python backend/scripts/debugging/opensearch/opensearch_debug.py list
python backend/scripts/debugging/opensearch/opensearch_debug.py delete <index_name>
python backend/scripts/debugging/opensearch/opensearch_debug.py delete
<index_name>
Environment Variables:
OPENSEARCH_HOST: OpenSearch host
@@ -17,10 +18,11 @@ Dependencies:
backend/shared_configs/configs.py
backend/onyx/document_index/opensearch/client.py
"""
import argparse
import json
import os
import sys
from typing import Any
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
@@ -61,6 +63,43 @@ def delete_index(client: OpenSearchIndexClient) -> None:
print(f"Failed to delete index '{client._index_name}' for an unknown reason.")
def get_settings(
    client: OpenSearchIndexClient,
    include_defaults: bool = False,
    flat_settings: bool = False,
    pretty: bool = False,
    human: bool = False,
) -> None:
    """Fetch and print the settings for the client's index as indented JSON.

    Args:
        client: Index-scoped OpenSearch client to query.
        include_defaults: Also request the cluster-default settings.
        flat_settings: Request dot-notation (flat) setting keys.
        pretty: Request pretty-formatted output from OpenSearch.
        human: Request human-readable values (e.g. sizes/durations).
    """
    # client.get_settings returns a (settings, default_settings) pair;
    # default_settings is presumably empty unless include_defaults is set —
    # TODO confirm against OpenSearchIndexClient.get_settings.
    settings, default_settings = client.get_settings(
        include_defaults=include_defaults,
        flat_settings=flat_settings,
        pretty=pretty,
        human=human,
    )
    print("Settings:")
    print(json.dumps(settings, indent=4))
    print("-" * 80)
    # Only print the defaults section when something was returned.
    if default_settings:
        print("Default settings:")
        print(json.dumps(default_settings, indent=4))
        print("-" * 80)
def set_settings(client: OpenSearchIndexClient, settings: dict[str, Any]) -> None:
    """Apply *settings* to the client's index and report success."""
    client.update_settings(settings)
    print(f"Updated settings for index '{client._index_name}'.")
def open_index(client: OpenSearchIndexClient) -> None:
    """Open the client's index and report success."""
    client.open_index()
    print(f"Index '{client._index_name}' opened.")
def close_index(client: OpenSearchIndexClient) -> None:
    """Close the client's index and report success."""
    client.close_index()
    print(f"Index '{client._index_name}' closed.")
def main() -> None:
def add_standard_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
@@ -77,13 +116,19 @@ def main() -> None:
)
parser.add_argument(
"--username",
help="OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, then prompt for input.",
help=(
"OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, then prompt for "
"input."
),
type=str,
default=os.environ.get("OPENSEARCH_ADMIN_USERNAME", ""),
)
parser.add_argument(
"--password",
help="OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, then prompt for input.",
help=(
"OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, then prompt for "
"input."
),
type=str,
default=os.environ.get("OPENSEARCH_ADMIN_PASSWORD", ""),
)
@@ -118,6 +163,47 @@ def main() -> None:
delete_parser = subparsers.add_parser("delete", help="Delete an index.")
delete_parser.add_argument("index", help="Index name.", type=str)
get_settings_parser = subparsers.add_parser(
"get", help="Get settings for an index."
)
get_settings_parser.add_argument("index", help="Index name.", type=str)
get_settings_parser.add_argument(
"--include-defaults",
help="Include default settings.",
action="store_true",
default=False,
)
get_settings_parser.add_argument(
"--flat-settings",
help="Return settings in flat format.",
action="store_true",
default=False,
)
get_settings_parser.add_argument(
"--pretty",
help="Pretty-format the returned JSON response.",
action="store_true",
default=False,
)
get_settings_parser.add_argument(
"--human",
help="Return statistics in human-readable format.",
action="store_true",
default=False,
)
set_settings_parser = subparsers.add_parser(
"set", help="Set settings for an index."
)
set_settings_parser.add_argument("index", help="Index name.", type=str)
set_settings_parser.add_argument("settings", help="Settings to set.", type=str)
open_index_parser = subparsers.add_parser("open", help="Open an index.")
open_index_parser.add_argument("index", help="Index name.", type=str)
close_index_parser = subparsers.add_parser("close", help="Close an index.")
close_index_parser.add_argument("index", help="Index name.", type=str)
args = parser.parse_args()
if not (host := args.host or input("Enter the OpenSearch host: ")):
@@ -134,18 +220,19 @@ def main() -> None:
sys.exit(1)
print("Using AWS-managed OpenSearch: ", args.use_aws_managed_opensearch)
print(f"MULTI_TENANT: {MULTI_TENANT}")
print()
with (
OpenSearchIndexClient(
index_name=args.index,
OpenSearchClient(
host=host,
port=port,
auth=(username, password),
use_ssl=not args.no_ssl,
verify_certs=not args.no_verify_certs,
)
if args.command == "delete"
else OpenSearchClient(
if args.command == "list"
else OpenSearchIndexClient(
index_name=args.index,
host=host,
port=port,
auth=(username, password),
@@ -161,6 +248,23 @@ def main() -> None:
list_indices(client)
elif args.command == "delete":
delete_index(client)
elif args.command == "get":
get_settings(
client,
include_defaults=args.include_defaults,
flat_settings=args.flat_settings,
pretty=args.pretty,
human=args.human,
)
elif args.command == "set":
set_settings(client, json.loads(args.settings))
elif args.command == "open":
open_index(client)
elif args.command == "close":
close_index(client)
else:
print(f"Unknown command: {args.command}")
sys.exit(1)
if __name__ == "__main__":

View File

@@ -456,7 +456,7 @@ class TestOpenSearchClient:
# Assert that the current number of replicas is not the desired test
# number we are updating to.
test_num_replicas = 0
current_settings = test_client.get_settings()
current_settings, _ = test_client.get_settings()
assert current_settings["index"]["number_of_replicas"] != f"{test_num_replicas}"
# Under test.
@@ -467,7 +467,7 @@ class TestOpenSearchClient:
)
# Postcondition.
current_settings = test_client.get_settings()
current_settings, _ = test_client.get_settings()
assert current_settings["index"]["number_of_replicas"] == f"{test_num_replicas}"
def test_update_settings_on_nonexistent_index(
@@ -488,7 +488,7 @@ class TestOpenSearchClient:
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
current_settings = test_client.get_settings()
current_settings, _ = test_client.get_settings()
# Postcondition.
assert "index" in current_settings

View File

@@ -183,3 +183,30 @@ def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
"""Verify unknown IDs return 404."""
response = _get_chat_session(str(uuid4()), basic_user)
assert response.status_code == 404
def _stop_chat_session(chat_session_id: str, user: DATestUser) -> requests.Response:
    """POST to the stop-chat-session endpoint as *user*; return the raw response."""
    return requests.post(
        f"{API_SERVER_URL}/chat/stop-chat-session/{chat_session_id}",
        headers=user.headers,
        cookies=user.cookies,
    )
def test_stop_chat_session_rejects_non_owner(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Non-owner callers must not be able to stop another user's chat session."""
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    # Owner can stop their own session.
    response = _stop_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 200
    # A different authenticated user must not be able to stop it. 404 matches
    # the unknown-ID case below, so session existence is not revealed.
    response = _stop_chat_session(str(chat_session.id), second_user)
    assert response.status_code == 404
    # Unknown session IDs are also rejected.
    response = _stop_chat_session(str(uuid4()), second_user)
    assert response.status_code == 404

View File

@@ -0,0 +1,47 @@
"""Unit tests for the OAuth reCAPTCHA cookie helpers in onyx.auth.captcha."""
from onyx.auth import captcha as captcha_module
def test_issued_cookie_validates() -> None:
    """A freshly issued cookie passes validation."""
    cookie = captcha_module.issue_captcha_cookie_value()
    assert captcha_module.validate_captcha_cookie_value(cookie) is True
def test_validate_rejects_none() -> None:
    """A missing cookie value (None) never validates."""
    assert captcha_module.validate_captcha_cookie_value(None) is False
def test_validate_rejects_empty_string() -> None:
    """An empty cookie value never validates."""
    assert captcha_module.validate_captcha_cookie_value("") is False
def test_validate_rejects_malformed_no_separator() -> None:
    """A value without the '.' expiry/signature separator is rejected."""
    assert captcha_module.validate_captcha_cookie_value("nodot") is False
def test_validate_rejects_non_numeric_expiry() -> None:
    """A non-integer expiry segment is rejected."""
    assert captcha_module.validate_captcha_cookie_value("notanumber.deadbeef") is False
def test_validate_rejects_tampered_signature() -> None:
    """Swapping the signature while keeping the expiry is rejected."""
    cookie = captcha_module.issue_captcha_cookie_value()
    expiry, _sig = cookie.split(".", 1)
    # Keep the real expiry but substitute an arbitrary hex signature.
    tampered = f"{expiry}.deadbeefdeadbeefdeadbeefdeadbeef"
    assert captcha_module.validate_captcha_cookie_value(tampered) is False
def test_validate_rejects_expired_timestamp() -> None:
    """An expiry in the past is rejected even with a valid signature."""
    # now=0 dates the cookie to the epoch, far beyond any plausible TTL.
    cookie = captcha_module.issue_captcha_cookie_value(now=0)
    assert captcha_module.validate_captcha_cookie_value(cookie) is False
def test_validate_rejects_modified_expiry() -> None:
    """Bumping the expiry forward invalidates the signature."""
    cookie = captcha_module.issue_captcha_cookie_value()
    _expiry, sig = cookie.split(".", 1)
    # Keep the genuine signature but claim a far-future expiry; the HMAC
    # covers the expiry, so this must fail.
    bumped = f"99999999999.{sig}"
    assert captcha_module.validate_captcha_cookie_value(bumped) is False

View File

@@ -0,0 +1,172 @@
"""Unit tests for the reCAPTCHA token replay cache."""
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.auth import captcha as captcha_module
from onyx.auth.captcha import _replay_cache_key
from onyx.auth.captcha import _reserve_token_or_raise
from onyx.auth.captcha import CaptchaVerificationError
from onyx.auth.captcha import verify_captcha_token
@pytest.mark.asyncio
async def test_reserve_token_succeeds_on_first_use() -> None:
    """First SETNX claims the token; no exception."""
    fake_redis = MagicMock()
    # set(..., nx=True) returning True means the key did not exist yet.
    fake_redis.set = AsyncMock(return_value=True)
    with patch.object(
        captcha_module,
        "get_async_redis_connection",
        AsyncMock(return_value=fake_redis),
    ):
        await _reserve_token_or_raise("some-token")
    fake_redis.set.assert_awaited_once()
    await_args = fake_redis.set.await_args
    assert await_args is not None
    # nx=True gives SETNX semantics; ex=120 matches _REPLAY_CACHE_TTL_SECONDS.
    assert await_args.kwargs["nx"] is True
    assert await_args.kwargs["ex"] == 120
@pytest.mark.asyncio
async def test_reserve_token_rejects_replay() -> None:
    """Second use of the same token within TTL → CaptchaVerificationError."""
    fake_redis = MagicMock()
    # set(..., nx=True) returning False means the key already existed.
    fake_redis.set = AsyncMock(return_value=False)
    with patch.object(
        captcha_module,
        "get_async_redis_connection",
        AsyncMock(return_value=fake_redis),
    ):
        with pytest.raises(CaptchaVerificationError, match="token already used"):
            await _reserve_token_or_raise("replayed-token")
@pytest.mark.asyncio
async def test_reserve_token_fails_open_on_redis_error() -> None:
    """A Redis blip must NOT block legitimate registrations."""
    with patch.object(
        captcha_module,
        "get_async_redis_connection",
        AsyncMock(side_effect=RuntimeError("redis down")),
    ):
        # No exception raised — replay protection is gracefully skipped.
        await _reserve_token_or_raise("any-token")
def test_replay_cache_key_is_sha256_prefixed() -> None:
    """The stored key never contains the raw token."""
    key = _replay_cache_key("raw-value")
    assert key.startswith("captcha:replay:")
    assert "raw-value" not in key
    # Length = prefix + 64 hex chars.
    assert len(key) == len("captcha:replay:") + 64
@pytest.mark.asyncio
async def test_reservation_released_when_google_unreachable() -> None:
    """If Google's siteverify itself errors (our side, not the token's), the
    replay reservation must be released so the user can retry with the same
    still-valid token instead of getting 'already used' for 120s."""
    fake_redis = MagicMock()
    fake_redis.set = AsyncMock(return_value=True)
    # delete is the release path; its call is asserted below.
    fake_redis.delete = AsyncMock(return_value=1)
    fake_client = MagicMock()
    # The POST to siteverify fails at the network level.
    fake_client.post = AsyncMock(side_effect=httpx.ConnectError("network down"))
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=None)
    with (
        patch.object(captcha_module, "is_captcha_enabled", return_value=True),
        patch.object(
            captcha_module,
            "get_async_redis_connection",
            AsyncMock(return_value=fake_redis),
        ),
        patch.object(captcha_module.httpx, "AsyncClient", return_value=fake_client),
    ):
        with pytest.raises(CaptchaVerificationError, match="service unavailable"):
            await verify_captcha_token("valid-token", expected_action="signup")
    # The reservation was claimed and then released.
    fake_redis.set.assert_awaited_once()
    fake_redis.delete.assert_awaited_once()
@pytest.mark.asyncio
async def test_reservation_released_on_unexpected_response_shape() -> None:
    """Non-HTTP errors during response parsing (malformed JSON, pydantic
    validation failure) also release the reservation — they mean WE couldn't
    verify the token, not that the token is definitively invalid."""
    fake_redis = MagicMock()
    fake_redis.set = AsyncMock(return_value=True)
    fake_redis.delete = AsyncMock(return_value=1)
    # Simulate Google returning something that json() still succeeds on but
    # fails RecaptchaResponse validation (e.g. success=true but with a wrong
    # shape that Pydantic rejects when coerced).
    fake_httpx_response = MagicMock()
    fake_httpx_response.raise_for_status = MagicMock()
    fake_httpx_response.json = MagicMock(side_effect=ValueError("not valid JSON"))
    fake_client = MagicMock()
    fake_client.post = AsyncMock(return_value=fake_httpx_response)
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=None)
    with (
        patch.object(captcha_module, "is_captcha_enabled", return_value=True),
        patch.object(
            captcha_module,
            "get_async_redis_connection",
            AsyncMock(return_value=fake_redis),
        ),
        patch.object(captcha_module.httpx, "AsyncClient", return_value=fake_client),
    ):
        with pytest.raises(CaptchaVerificationError, match="service unavailable"):
            await verify_captcha_token("valid-token", expected_action="signup")
    # Claimed, then released — same contract as the network-error case.
    fake_redis.set.assert_awaited_once()
    fake_redis.delete.assert_awaited_once()
@pytest.mark.asyncio
async def test_reservation_kept_when_google_rejects_token() -> None:
    """If Google itself says the token is invalid (success=false, or score
    too low), the reservation must NOT be released — that token is known-bad
    for its entire lifetime and shouldn't be retryable."""
    fake_redis = MagicMock()
    fake_redis.set = AsyncMock(return_value=True)
    fake_redis.delete = AsyncMock(return_value=1)
    # A well-formed siteverify response with a definitive rejection.
    fake_httpx_response = MagicMock()
    fake_httpx_response.raise_for_status = MagicMock()
    fake_httpx_response.json = MagicMock(
        return_value={
            "success": False,
            "error-codes": ["invalid-input-response"],
        }
    )
    fake_client = MagicMock()
    fake_client.post = AsyncMock(return_value=fake_httpx_response)
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=None)
    with (
        patch.object(captcha_module, "is_captcha_enabled", return_value=True),
        patch.object(
            captcha_module,
            "get_async_redis_connection",
            AsyncMock(return_value=fake_redis),
        ),
        patch.object(captcha_module.httpx, "AsyncClient", return_value=fake_client),
    ):
        with pytest.raises(CaptchaVerificationError, match="invalid-input-response"):
            await verify_captcha_token("bad-token", expected_action="signup")
    # Claimed but NOT released: the replay block stays for the token's TTL.
    fake_redis.set.assert_awaited_once()
    fake_redis.delete.assert_not_awaited()

View File

@@ -0,0 +1,78 @@
"""Unit tests for the require-score check in verify_captcha_token."""
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.auth import captcha as captcha_module
from onyx.auth.captcha import CaptchaVerificationError
from onyx.auth.captcha import verify_captcha_token
def _fake_httpx_client_returning(payload: dict) -> MagicMock:
resp = MagicMock()
resp.raise_for_status = MagicMock()
resp.json = MagicMock(return_value=payload)
client = MagicMock()
client.post = AsyncMock(return_value=resp)
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=None)
return client
@pytest.mark.asyncio
async def test_rejects_when_score_missing() -> None:
    """Siteverify response with no score field is rejected outright —
    closes the accidental 'test secret in prod' bypass path."""
    # success=True but no "score" key at all.
    client = _fake_httpx_client_returning(
        {"success": True, "hostname": "testkey.google.com"}
    )
    with (
        patch.object(captcha_module, "is_captcha_enabled", return_value=True),
        patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
    ):
        with pytest.raises(CaptchaVerificationError, match="missing score"):
            await verify_captcha_token("test-token", expected_action="signup")
@pytest.mark.asyncio
async def test_accepts_when_score_present_and_above_threshold() -> None:
    """Sanity check the happy path still works with the tighter score rule."""
    client = _fake_httpx_client_returning(
        {
            "success": True,
            "score": 0.9,
            "action": "signup",
            "hostname": "cloud.onyx.app",
        }
    )
    with (
        patch.object(captcha_module, "is_captcha_enabled", return_value=True),
        patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
    ):
        # Should not raise.
        await verify_captcha_token("fresh-token", expected_action="signup")
@pytest.mark.asyncio
async def test_rejects_when_score_below_threshold() -> None:
    """A score present but below threshold still rejects (existing behavior,
    guarding against regression from this PR's restructure)."""
    client = _fake_httpx_client_returning(
        {
            "success": True,
            "score": 0.1,
            "action": "signup",
            "hostname": "cloud.onyx.app",
        }
    )
    with (
        patch.object(captcha_module, "is_captcha_enabled", return_value=True),
        patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
    ):
        with pytest.raises(
            CaptchaVerificationError, match="suspicious activity detected"
        ):
            await verify_captcha_token("low-score-token", expected_action="signup")

View File

@@ -0,0 +1,210 @@
"""Unit tests for the reCAPTCHA OAuth verify endpoint + cookie middleware."""
from unittest.mock import AsyncMock
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from onyx.auth.captcha import CaptchaVerificationError
from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.server.auth import captcha_api as captcha_api_module
from onyx.server.auth.captcha_api import CaptchaCookieMiddleware
from onyx.server.auth.captcha_api import router as captcha_router
def build_app_with_middleware() -> FastAPI:
    """Minimal app with the middleware + router + fake OAuth callback route.

    Returns a fresh FastAPI app per call so each test gets isolated state.
    """
    app = FastAPI()
    # Exception handlers registered so OnyxErrors raised by the middleware or
    # endpoint render as the JSON error bodies asserted in the tests below.
    register_onyx_exception_handlers(app)
    app.add_middleware(CaptchaCookieMiddleware)
    app.include_router(captcha_router)
    @app.get("/auth/oauth/callback")
    async def _oauth_callback() -> dict[str, str]:
        return {"status": "ok"}
    @app.post("/auth/register")
    async def _register() -> dict[str, str]:
        # /auth/register is NOT gated by this middleware (it has its own
        # captcha enforcement in UserManager.create). Used here to prove the
        # middleware only touches /auth/oauth/callback.
        return {"status": "created"}
    @app.get("/me")
    async def _me() -> dict[str, str]:
        return {"status": "not-guarded"}
    return app
# ---------- /auth/captcha/oauth-verify endpoint ----------
def test_verify_endpoint_returns_ok_when_captcha_disabled() -> None:
    """Dormant mode: endpoint is a no-op, no cookie issued."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=False):
        res = client.post("/auth/captcha/oauth-verify", json={"token": "whatever"})
    assert res.status_code == 200
    assert res.json() == {"ok": True}
    from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
    assert CAPTCHA_COOKIE_NAME not in res.cookies
def test_verify_endpoint_sets_cookie_on_success() -> None:
    """A verified token results in the signed captcha cookie being set."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with (
        patch.object(captcha_api_module, "is_captcha_enabled", return_value=True),
        # verify_captcha_token resolving (returning None) == verification OK.
        patch.object(
            captcha_api_module,
            "verify_captcha_token",
            AsyncMock(return_value=None),
        ),
    ):
        res = client.post("/auth/captcha/oauth-verify", json={"token": "valid-token"})
    assert res.status_code == 200
    assert res.json() == {"ok": True}
    from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
    assert CAPTCHA_COOKIE_NAME in res.cookies
def test_verify_endpoint_raises_onyx_error_on_failure() -> None:
    """Verification failure surfaces as a 403 with the Google error detail."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with (
        patch.object(captcha_api_module, "is_captcha_enabled", return_value=True),
        patch.object(
            captcha_api_module,
            "verify_captcha_token",
            AsyncMock(
                side_effect=CaptchaVerificationError(
                    "Captcha verification failed: invalid-input-response"
                )
            ),
        ),
    ):
        res = client.post("/auth/captcha/oauth-verify", json={"token": "bad-token"})
    assert res.status_code == 403
    body = res.json()
    assert body["error_code"] == "UNAUTHORIZED"
    assert "invalid-input-response" in body["detail"]
def test_verify_endpoint_rejects_missing_token() -> None:
    """A body without `token` never reaches verification."""
    app = build_app_with_middleware()
    client = TestClient(app)
    res = client.post("/auth/captcha/oauth-verify", json={})
    # Pydantic validation failure from missing `token`.
    assert res.status_code == 422
# ---------- CaptchaCookieMiddleware ----------
def test_middleware_passes_through_when_captcha_disabled() -> None:
    """Dormant mode: the middleware does not gate the OAuth callback."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=False):
        res = client.get("/auth/oauth/callback")
    assert res.status_code == 200
    assert res.json() == {"status": "ok"}
def test_middleware_blocks_oauth_callback_without_cookie() -> None:
    """With captcha on, a cookie-less callback request is rejected."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        res = client.get("/auth/oauth/callback")
    assert res.status_code == 403
    body = res.json()
    assert body["error_code"] == "UNAUTHORIZED"
    assert "Captcha challenge required" in body["detail"]
def test_middleware_allows_oauth_callback_with_valid_cookie() -> None:
    """A correctly-signed unexpired cookie lets the OAuth callback through."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        cookie_value = captcha_api_module.issue_captcha_cookie_value()
        from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
        res = client.get(
            "/auth/oauth/callback",
            cookies={CAPTCHA_COOKIE_NAME: cookie_value},
        )
    assert res.status_code == 200
    assert res.json() == {"status": "ok"}
def test_middleware_clears_cookie_after_successful_callback() -> None:
    """One-time-use: cookie is deleted after the callback has been served so
    a replayed callback URL cannot re-enter without a fresh challenge."""
    app = build_app_with_middleware()
    client = TestClient(app)
    from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        cookie_value = captcha_api_module.issue_captcha_cookie_value()
        res = client.get(
            "/auth/oauth/callback",
            cookies={CAPTCHA_COOKIE_NAME: cookie_value},
        )
    assert res.status_code == 200
    set_cookie = res.headers.get("set-cookie", "")
    # Starlette's delete_cookie emits an expired Max-Age=0 Set-Cookie for the name.
    assert CAPTCHA_COOKIE_NAME in set_cookie
    # Compare case-insensitively: the original check searched for a
    # capitalized 'Thu, 01 Jan 1970' inside set_cookie.lower(), which can
    # never match — the fallback branch was dead.
    lowered = set_cookie.lower()
    assert "max-age=0" in lowered or 'expires="thu, 01 jan 1970' in lowered
def test_middleware_rejects_tampered_cookie() -> None:
    """A cookie with a future expiry but a bogus signature is rejected."""
    app = build_app_with_middleware()
    client = TestClient(app)
    from onyx.auth.captcha import CAPTCHA_COOKIE_NAME
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        res = client.get(
            "/auth/oauth/callback",
            cookies={CAPTCHA_COOKIE_NAME: "9999999999.deadbeef"},
        )
    assert res.status_code == 403
def test_middleware_ignores_register_path() -> None:
    """/auth/register has its own captcha enforcement in UserManager.create —
    the cookie middleware should NOT gate it."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        res = client.post("/auth/register", json={})
    assert res.status_code == 200
def test_middleware_ignores_unrelated_paths() -> None:
    """Routes other than the OAuth callback are untouched by the gate."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        res = client.get("/me")
    assert res.status_code == 200
def test_middleware_skips_options_preflight() -> None:
    """CORS preflight must pass through even without a cookie."""
    app = build_app_with_middleware()
    client = TestClient(app)
    with patch.object(captcha_api_module, "is_captcha_enabled", return_value=True):
        res = client.options("/auth/oauth/callback")
    # Not 403: preflight passed the captcha gate.
    assert res.status_code != 403
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -7,9 +7,12 @@ import pytest
from fastapi import HTTPException
from onyx.server.manage.models import EmailInviteStatus
from onyx.server.manage.models import UserByEmail
from onyx.server.manage.users import bulk_invite_users
from onyx.server.manage.users import remove_invited_user
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@@ -21,12 +24,13 @@ def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: None) -> None:
emails = [f"user{i}@example.com" for i in range(6)]
with pytest.raises(HTTPException) as exc_info:
bulk_invite_users(emails=emails)
bulk_invite_users(emails=emails, current_user=MagicMock())
assert exc_info.value.status_code == 403
assert "invite limit" in exc_info.value.detail.lower()
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
@@ -45,7 +49,7 @@ def test_trial_tenant_can_invite_within_limit(*_mocks: None) -> None:
"""Trial tenants can invite users when under the limit."""
emails = ["user1@example.com", "user2@example.com", "user3@example.com"]
result = bulk_invite_users(emails=emails)
result = bulk_invite_users(emails=emails, current_user=MagicMock())
assert result.invited_count == 3
assert result.email_invite_status == EmailInviteStatus.DISABLED
@@ -60,6 +64,7 @@ _COMMON_PATCHES = [
patch("onyx.server.manage.users.get_all_users", return_value=[]),
patch("onyx.server.manage.users.write_invited_users", return_value=1),
patch("onyx.server.manage.users.enforce_seat_limit"),
patch("onyx.server.manage.users.enforce_invite_rate_limit"),
]
@@ -73,7 +78,7 @@ def _with_common_patches(fn: object) -> object:
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
def test_email_invite_status_disabled(*_mocks: None) -> None:
"""When email invites are disabled, status is disabled."""
result = bulk_invite_users(emails=["user@example.com"])
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
assert result.email_invite_status == EmailInviteStatus.DISABLED
@@ -83,7 +88,7 @@ def test_email_invite_status_disabled(*_mocks: None) -> None:
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", False)
def test_email_invite_status_not_configured(*_mocks: None) -> None:
"""When email invites are enabled but no server is configured, status is not_configured."""
result = bulk_invite_users(emails=["user@example.com"])
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
assert result.email_invite_status == EmailInviteStatus.NOT_CONFIGURED
@@ -94,7 +99,7 @@ def test_email_invite_status_not_configured(*_mocks: None) -> None:
@patch("onyx.server.manage.users.send_user_email_invite")
def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
"""When email invites are enabled and configured, status is sent."""
result = bulk_invite_users(emails=["user@example.com"])
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
mock_send.assert_called_once()
assert result.email_invite_status == EmailInviteStatus.SENT
@@ -109,7 +114,123 @@ def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
)
def test_email_invite_status_send_failed(*_mocks: None) -> None:
"""When email sending throws, status is send_failed and invite is still saved."""
result = bulk_invite_users(emails=["user@example.com"])
result = bulk_invite_users(emails=["user@example.com"], current_user=MagicMock())
assert result.email_invite_status == EmailInviteStatus.SEND_FAILED
assert result.invited_count == 1
# --- trial-only rate limit gating tests ---
# The patch stacks below stub out every DB/Redis/EE dependency so only the
# trial-vs-paid gating around the rate limiter is exercised.
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=False)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
@patch("onyx.server.manage.users.enforce_seat_limit")
@patch(
    "onyx.server.manage.users.fetch_ee_implementation_or_noop",
    return_value=lambda *_args: None,
)
def test_paid_tenant_bypasses_invite_rate_limit(
    _ee_fetch: MagicMock,
    _seat_limit: MagicMock,
    _write_invited: MagicMock,
    _get_all_users: MagicMock,
    _get_invited_users: MagicMock,
    _get_tenant_id: MagicMock,
    _is_trial: MagicMock,
    mock_rate_limit: MagicMock,
) -> None:
    """Paid tenants must not hit the invite rate limiter at all."""
    emails = [f"user{i}@example.com" for i in range(3)]
    bulk_invite_users(emails=emails, current_user=MagicMock())
    mock_rate_limit.assert_not_called()
# Same stack as above but with is_tenant_on_trial_fn=True, plus a generous
# NUM_FREE_TRIAL_USER_INVITES so the limiter is reached without tripping.
@patch("onyx.server.manage.users.enforce_invite_rate_limit")
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
@patch("onyx.server.manage.users.enforce_seat_limit")
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 50)
@patch(
    "onyx.server.manage.users.fetch_ee_implementation_or_noop",
    return_value=lambda *_args: None,
)
def test_trial_tenant_hits_invite_rate_limit(
    _ee_fetch: MagicMock,
    _seat_limit: MagicMock,
    _write_invited: MagicMock,
    _get_all_users: MagicMock,
    _get_invited_users: MagicMock,
    _get_tenant_id: MagicMock,
    _is_trial: MagicMock,
    mock_rate_limit: MagicMock,
) -> None:
    """Trial tenants must flow through the invite rate limiter."""
    emails = [f"user{i}@example.com" for i in range(3)]
    bulk_invite_users(emails=emails, current_user=MagicMock())
    mock_rate_limit.assert_called_once()
# Mirror tests for remove_invited_user: the remove-invited rate limiter must
# only apply to trial tenants.
@patch("onyx.server.manage.users.enforce_remove_invited_rate_limit")
@patch("onyx.server.manage.users.remove_user_from_invited_users", return_value=0)
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=False)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch(
    "onyx.server.manage.users.fetch_ee_implementation_or_noop",
    return_value=lambda *_args: None,
)
def test_paid_tenant_bypasses_remove_invited_rate_limit(
    _ee_fetch: MagicMock,
    _get_tenant_id: MagicMock,
    _is_trial: MagicMock,
    _remove_from_invited: MagicMock,
    mock_rate_limit: MagicMock,
) -> None:
    """Paid tenants must not hit the remove-invited rate limiter at all."""
    remove_invited_user(
        user_email=UserByEmail(user_email="user@example.com"),
        current_user=MagicMock(),
        db_session=MagicMock(),
    )
    mock_rate_limit.assert_not_called()
@patch("onyx.server.manage.users.enforce_remove_invited_rate_limit")
@patch("onyx.server.manage.users.remove_user_from_invited_users", return_value=0)
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch(
    "onyx.server.manage.users.fetch_ee_implementation_or_noop",
    return_value=lambda *_args: None,
)
def test_trial_tenant_hits_remove_invited_rate_limit(
    _ee_fetch: MagicMock,
    _get_tenant_id: MagicMock,
    _is_trial: MagicMock,
    _remove_from_invited: MagicMock,
    mock_rate_limit: MagicMock,
) -> None:
    """Trial tenants must flow through the remove-invited rate limiter."""
    remove_invited_user(
        user_email=UserByEmail(user_email="user@example.com"),
        current_user=MagicMock(),
        db_session=MagicMock(),
    )
    mock_rate_limit.assert_called_once()

View File

@@ -0,0 +1,374 @@
"""Unit tests for the Redis-backed invite + remove-invited rate limits."""
from typing import Any
from typing import cast
from unittest.mock import patch
from uuid import uuid4
import pytest
from redis import Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.invite_rate_limit import enforce_invite_rate_limit
from onyx.server.manage.invite_rate_limit import enforce_remove_invited_rate_limit
class _StubRedis:
"""In-memory stand-in that mirrors the Lua script's semantics.
The production rate limiter drives all state via a single EVAL call.
The stub reimplements that logic in Python so unit tests can assert
behavior without a live Redis — matching semantics including the
NX-style TTL that leaves existing TTLs intact on re-increment.
"""
def __init__(self) -> None:
self.store: dict[str, int] = {}
self.ttls: dict[str, int] = {}
self.eval_fail: Exception | None = None
def eval(self, _script: str, num_keys: int, *args: Any) -> int:
if self.eval_fail is not None:
raise self.eval_fail
keys = list(args[:num_keys])
argv = list(args[num_keys:])
n = int(argv[0])
for i in range(n):
key = keys[i]
increment = int(argv[1 + i * 3])
limit = int(argv[2 + i * 3])
if limit > 0 and increment > 0:
current = self.store.get(key, 0)
if current + increment > limit:
return i + 1
for i in range(n):
key = keys[i]
increment = int(argv[1 + i * 3])
limit = int(argv[2 + i * 3])
ttl = int(argv[3 + i * 3])
if limit > 0 and increment > 0:
self.store[key] = self.store.get(key, 0) + increment
if key not in self.ttls:
self.ttls[key] = ttl
return 0
def _stub() -> Redis:
    """Return a fresh in-memory stub typed as a real Redis client."""
    fake = _StubRedis()
    return cast(Redis, fake)
def test_invite_allows_under_all_tiers() -> None:
    """A request comfortably under every tier's cap must be admitted."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 500
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5
    ):
        enforce_invite_rate_limit(
            client, admin_id, num_invites=10, tenant_id="tenant_a"
        )
    backend = cast(_StubRedis, client)
    # Both daily counters absorb the full batch; the burst bucket ticks once.
    assert backend.store[f"ratelimit:invite_put:admin:{admin_id}:day"] == 10
    assert backend.store["ratelimit:invite_put:tenant:tenant_a:day"] == 10
    assert backend.store[f"ratelimit:invite_put:admin:{admin_id}:min"] == 1
def test_invite_minute_bucket_blocks_request_flood() -> None:
    """Attacker firing single-email invites rapidly must trip admin/minute."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 5000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5
    ):
        # The first five requests fit within the minute budget...
        for _ in range(5):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=1, tenant_id="tenant_a"
            )
        # ...and the sixth must be rejected as rate-limited.
        with pytest.raises(OnyxError) as exc_info:
            enforce_invite_rate_limit(
                client, admin_id, num_invites=1, tenant_id="tenant_a"
            )
    assert exc_info.value.error_code == OnyxErrorCode.RATE_LIMITED
def test_invite_bulk_call_does_not_trip_minute_bucket() -> None:
    """One legitimate bulk call for many users counts as a single burst unit."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 500
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5
    ):
        enforce_invite_rate_limit(
            client, admin_id, num_invites=20, tenant_id="tenant_a"
        )
    backend = cast(_StubRedis, client)
    # Twenty invites in one request still tick the minute bucket exactly once.
    assert backend.store[f"ratelimit:invite_put:admin:{admin_id}:min"] == 1
def test_invite_admin_daily_cap_enforced() -> None:
    """An admin who exhausts the daily budget is blocked on the next invite."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 5000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 1000
    ):
        # Consume the entire daily allowance in one shot...
        enforce_invite_rate_limit(
            client, admin_id, num_invites=50, tenant_id="tenant_a"
        )
        # ...then even a single extra invite must be refused.
        with pytest.raises(OnyxError):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=1, tenant_id="tenant_a"
            )
def test_invite_tenant_daily_cap_enforced_across_admins() -> None:
    """The tenant-wide cap must trip even when several admins share the load."""
    client = _stub()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 10
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 1000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 1000
    ):
        # Two different admins together consume the tenant's full budget.
        enforce_invite_rate_limit(
            client, uuid4(), num_invites=6, tenant_id="tenant_a"
        )
        enforce_invite_rate_limit(
            client, uuid4(), num_invites=4, tenant_id="tenant_a"
        )
        # Any further invite from anyone in the tenant must be rejected.
        with pytest.raises(OnyxError):
            enforce_invite_rate_limit(
                client, uuid4(), num_invites=1, tenant_id="tenant_a"
            )
def test_invite_rejected_request_does_not_consume_budget() -> None:
    """A request that violates one tier must leave every counter untouched."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 10
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 50
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 5
    ):
        enforce_invite_rate_limit(
            client, admin_id, num_invites=10, tenant_id="tenant_a"
        )
        # This request would blow the tenant cap, so it is rejected...
        with pytest.raises(OnyxError):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=5, tenant_id="tenant_a"
            )
    backend = cast(_StubRedis, client)
    # ...and none of the surviving tiers absorbed its increments.
    assert backend.store[f"ratelimit:invite_put:admin:{admin_id}:day"] == 10
    assert backend.store["ratelimit:invite_put:tenant:tenant_a:day"] == 10
    assert backend.store[f"ratelimit:invite_put:admin:{admin_id}:min"] == 1
def test_invite_zero_new_invites_still_ticks_minute_bucket() -> None:
    """Probes against already-invited emails must still tick the burst guard."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 5000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 2
    ):
        # Two zero-invite probes exhaust the minute budget of 2...
        for _ in range(2):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=0, tenant_id="tenant_a"
            )
        # ...so the third probe is rejected despite adding no invites.
        with pytest.raises(OnyxError):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=0, tenant_id="tenant_a"
            )
    backend = cast(_StubRedis, client)
    # The daily counters never moved: zero-increment buckets are skipped.
    assert backend.store.get(f"ratelimit:invite_put:admin:{admin_id}:day", 0) == 0
    assert backend.store.get("ratelimit:invite_put:tenant:tenant_a:day", 0) == 0
def test_invite_limit_zero_disables_tier() -> None:
    """A cap of 0 means 'tier disabled' — unlimited traffic passes through."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 0
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 0
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 0
    ):
        # With every tier disabled, no volume of traffic may raise.
        for _ in range(100):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=10, tenant_id="tenant_a"
            )
def test_invite_tenant_bucket_is_isolated_across_tenants() -> None:
    """Regression guard: tenants MUST NOT share the tenant/day counter.

    TenantRedis does not prefix keys passed to `eval`, so the tenant_id is
    baked into the key string itself."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 10
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 1000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 1000
    ):
        # Tenant A burns through its whole cap and is then refused.
        enforce_invite_rate_limit(
            client, admin_id, num_invites=10, tenant_id="tenant_a"
        )
        with pytest.raises(OnyxError):
            enforce_invite_rate_limit(
                client, admin_id, num_invites=1, tenant_id="tenant_a"
            )
        # Tenant B is unaffected and retains its full budget.
        enforce_invite_rate_limit(
            client, uuid4(), num_invites=10, tenant_id="tenant_b"
        )
    backend = cast(_StubRedis, client)
    assert backend.store["ratelimit:invite_put:tenant:tenant_a:day"] == 10
    assert backend.store["ratelimit:invite_put:tenant:tenant_b:day"] == 10
def test_invite_fails_open_when_redis_unavailable() -> None:
    """Onyx Lite deployments ship without Redis; invite flow must still work."""
    backend = _StubRedis()
    backend.eval_fail = RedisConnectionError("Redis not reachable")
    client = cast(Redis, backend)
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 1
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 1
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 1
    ):
        # Absurdly large request against a dead Redis: must not raise.
        enforce_invite_rate_limit(
            client, uuid4(), num_invites=1_000_000, tenant_id="tenant_a"
        )
def test_remove_minute_bucket_blocks_pattern_attack() -> None:
    """PUT→PATCH spam must trip the remove-invited minute bucket."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_DAY", 100
    ), patch(
        "onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_MIN", 3
    ):
        # Three removals fit inside the minute budget...
        for _ in range(3):
            enforce_remove_invited_rate_limit(client, admin_id)
        # ...the fourth must be rejected.
        with pytest.raises(OnyxError):
            enforce_remove_invited_rate_limit(client, admin_id)
def test_remove_daily_cap_enforced() -> None:
    """Exhausting the remove-invited daily budget blocks the next call."""
    client = _stub()
    admin_id = uuid4()
    with patch(
        "onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_DAY", 5
    ), patch(
        "onyx.server.manage.invite_rate_limit._REMOVE_ADMIN_PER_MIN", 1000
    ):
        # Five removals consume the whole daily allowance...
        for _ in range(5):
            enforce_remove_invited_rate_limit(client, admin_id)
        # ...so the sixth must be rejected.
        with pytest.raises(OnyxError):
            enforce_remove_invited_rate_limit(client, admin_id)
def test_ttls_set_on_first_increment_and_not_reset() -> None:
    """TTL must be recorded on first increment and left untouched afterwards."""
    client = _stub()
    admin_id = uuid4()
    day_key = f"ratelimit:invite_put:admin:{admin_id}:day"
    min_key = f"ratelimit:invite_put:admin:{admin_id}:min"
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 5000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 100
    ):
        enforce_invite_rate_limit(
            client, admin_id, num_invites=3, tenant_id="tenant_a"
        )
    backend = cast(_StubRedis, client)
    # First increment stamps each bucket with its window length.
    assert backend.ttls[day_key] == 24 * 60 * 60
    assert backend.ttls["ratelimit:invite_put:tenant:tenant_a:day"] == 24 * 60 * 60
    assert backend.ttls[min_key] == 60

    # Poison the stored TTL; a later increment must NOT overwrite it.
    backend.ttls[min_key] = 999
    with patch(
        "onyx.server.manage.invite_rate_limit._INVITE_TENANT_PER_DAY", 5000
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_DAY", 500
    ), patch(
        "onyx.server.manage.invite_rate_limit._INVITE_ADMIN_PER_MIN", 100
    ):
        enforce_invite_rate_limit(
            client, admin_id, num_invites=3, tenant_id="tenant_a"
        )
    assert backend.ttls[min_key] == 999

View File

@@ -0,0 +1,79 @@
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.db.enums import AccountType
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.users import list_all_users_basic_info
def _fake_user(
    email: str, account_type: AccountType = AccountType.STANDARD
) -> MagicMock:
    """Build a minimal user mock carrying only the fields the endpoint reads."""
    fake = MagicMock()
    fake.account_type = account_type
    fake.email = email
    fake.id = uuid4()
    return fake
@patch("onyx.server.manage.users.USER_DIRECTORY_ADMIN_ONLY", True)
def test_list_all_users_basic_info_blocks_non_admin_when_directory_restricted() -> None:
"""With the flag on, a caller lacking READ_USERS cannot enumerate the directory."""
user = MagicMock()
user.effective_permissions = [Permission.BASIC_ACCESS.value]
with pytest.raises(OnyxError) as exc_info:
list_all_users_basic_info(
include_api_keys=False,
user=user,
db_session=MagicMock(),
)
assert exc_info.value.error_code is OnyxErrorCode.INSUFFICIENT_PERMISSIONS
@patch("onyx.server.manage.users.USER_DIRECTORY_ADMIN_ONLY", True)
@patch("onyx.server.manage.users.get_all_users")
def test_list_all_users_basic_info_allows_admin_when_directory_restricted(
mock_get_all_users: MagicMock,
) -> None:
"""With the flag on, an admin (FULL_ADMIN_PANEL_ACCESS) still gets the directory."""
admin = MagicMock()
admin.effective_permissions = [Permission.FULL_ADMIN_PANEL_ACCESS.value]
mock_get_all_users.return_value = [_fake_user("a@example.com")]
result = list_all_users_basic_info(
include_api_keys=False,
user=admin,
db_session=MagicMock(),
)
assert [u.email for u in result] == ["a@example.com"]
@patch("onyx.server.manage.users.USER_DIRECTORY_ADMIN_ONLY", False)
@patch("onyx.server.manage.users.get_all_users")
def test_list_all_users_basic_info_allows_non_admin_when_flag_off(
mock_get_all_users: MagicMock,
) -> None:
"""With the flag off (default), non-admin callers continue to get the directory."""
basic = MagicMock()
basic.effective_permissions = [Permission.BASIC_ACCESS.value]
mock_get_all_users.return_value = [
_fake_user("human@example.com"),
_fake_user("bot@example.com", account_type=AccountType.BOT),
]
result = list_all_users_basic_info(
include_api_keys=False,
user=basic,
db_session=MagicMock(),
)
# BOT accounts are filtered out; human account is returned.
assert [u.email for u in result] == ["human@example.com"]

View File

@@ -40,6 +40,7 @@ services:
- SMTP_USER=${SMTP_USER:-}
- SMTP_PASS=${SMTP_PASS:-}
- ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-}
- USER_DIRECTORY_ADMIN_ONLY=${USER_DIRECTORY_ADMIN_ONLY:-}
- EMAIL_FROM=${EMAIL_FROM:-}
- OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}

View File

@@ -49,6 +49,11 @@ SESSION_EXPIRE_TIME_SECONDS=604800
#VALID_EMAIL_DOMAINS=
# Set to "true" to restrict GET /users to admins. Non-admins can't enumerate
# accounts or share agents with individual users; group sharing still works.
#USER_DIRECTORY_ADMIN_ONLY=
# Default values here are what Postgres uses by default, feel free to change.
POSTGRES_USER=postgres
POSTGRES_PASSWORD=password

View File

@@ -32,6 +32,9 @@ USER_AUTH_SECRET=""
# API_KEY_HASH_ROUNDS=
### You can add a comma separated list of domains like onyx.app, only those domains will be allowed to signup/log in
# VALID_EMAIL_DOMAINS=
### Set to "true" to restrict GET /users to admins. Non-admins can't enumerate
### accounts or share agents with individual users; group sharing still works.
# USER_DIRECTORY_ADMIN_ONLY=
## Chat Configuration
# HARD_DELETE_CHATS=
@@ -172,7 +175,7 @@ LOG_ONYX_MODEL_INTERACTIONS=False
## Gen AI Settings
# GEN_AI_MAX_TOKENS=
# LLM_SOCKET_READ_TIMEOUT=
LLM_SOCKET_READ_TIMEOUT=120
# MAX_CHUNKS_FED_TO_CHAT=
# DISABLE_LITELLM_STREAMING=
# LITELLM_EXTRA_HEADERS=

View File

@@ -1262,7 +1262,7 @@ configMap:
S3_FILE_STORE_BUCKET_NAME: ""
# Gen AI Settings
GEN_AI_MAX_TOKENS: ""
LLM_SOCKET_READ_TIMEOUT: "60"
LLM_SOCKET_READ_TIMEOUT: "120"
MAX_CHUNKS_FED_TO_CHAT: ""
# Query Options
DOC_TIME_DECAY: ""

View File

@@ -1,6 +1,14 @@
/**
* SignInButton — renders the SSO / OAuth sign-in button on the login page.
*
* When reCAPTCHA is enabled for this deployment (NEXT_PUBLIC_RECAPTCHA_SITE_KEY
* set at build time), the Google/OIDC/SAML OAuth click is intercepted to
* (1) fetch a reCAPTCHA v3 token for the "oauth" action, (2) POST it to
* /api/auth/captcha/oauth-verify which sets a signed HttpOnly cookie on the
* response, and (3) then navigate to the authorize URL. The cookie is sent
* automatically on the subsequent Google redirect back to our callback,
* where the backend middleware verifies it.
*
* IMPORTANT: This component is rendered as part of the /auth/login page, which
* is used in healthcheck and monitoring flows that issue headless (non-browser)
* requests (e.g. `curl`). During server-side rendering of those requests,
@@ -18,10 +26,13 @@
"use client";
import { useState } from "react";
import { Button } from "@opal/components";
import { AuthType } from "@/lib/constants";
import { FcGoogle } from "react-icons/fc";
import type { IconProps } from "@opal/types";
import { useCaptcha } from "@/lib/hooks/useCaptcha";
import Text from "@/refresh-components/texts/Text";
interface SignInButtonProps {
authorizeUrl: string;
@@ -32,6 +43,10 @@ export default function SignInButton({
authorizeUrl,
authType,
}: SignInButtonProps) {
const { getCaptchaToken, isCaptchaEnabled } = useCaptcha();
const [isVerifying, setIsVerifying] = useState(false);
const [error, setError] = useState<string | null>(null);
let button: string | undefined;
let icon: React.FunctionComponent<IconProps> | undefined;
@@ -48,18 +63,84 @@ export default function SignInButton({
throw new Error(`Unhandled authType: ${authType}`);
}
return (
<Button
prominence={
authType === AuthType.GOOGLE_OAUTH || authType === AuthType.CLOUD
? "secondary"
: "primary"
async function handleClick(e: React.MouseEvent) {
e.preventDefault();
if (isVerifying) return;
setIsVerifying(true);
setError(null);
// Stays true on the success branch so the button remains disabled until
// the browser actually begins unloading for the OAuth redirect — prevents
// a double-click window between `window.location.href = ...` and unload.
let navigating = false;
try {
const token = await getCaptchaToken("oauth");
if (!token) {
// eslint-disable-next-line no-console
console.error(
"Captcha: grecaptcha.execute returned no token. The widget may not have loaded yet."
);
setError("grecaptcha.execute returned no token");
return;
}
width="full"
icon={icon}
href={authorizeUrl}
>
{button}
</Button>
const res = await fetch("/api/auth/captcha/oauth-verify", {
method: "POST",
credentials: "include",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ token }),
});
if (!res.ok) {
const body = await res.json().catch(() => ({}));
// eslint-disable-next-line no-console
console.error(
`Captcha verify rejected: status=${res.status} detail=${
body.detail ?? "(none)"
}`
);
setError(
"Captcha verification failed. Please refresh your browser and try again."
);
return;
}
navigating = true;
window.location.href = authorizeUrl;
} catch (exc) {
// eslint-disable-next-line no-console
console.error("Captcha verify request failed", exc);
setError(exc instanceof Error ? exc.message : String(exc));
} finally {
if (!navigating) setIsVerifying(false);
}
}
// Only the Google OAuth callback is gated by CaptchaCookieMiddleware on the
// backend. OIDC/SAML callbacks have no cookie requirement, so running the
// reCAPTCHA interception for them is wasted friction — and worse, a failed
// captcha would block the sign-in entirely.
const intercepted =
isCaptchaEnabled &&
(authType === AuthType.GOOGLE_OAUTH || authType === AuthType.CLOUD);
return (
<>
<Button
prominence={
authType === AuthType.GOOGLE_OAUTH || authType === AuthType.CLOUD
? "secondary"
: "primary"
}
width="full"
icon={icon}
href={intercepted ? undefined : authorizeUrl}
onClick={intercepted ? handleClick : undefined}
disabled={isVerifying}
>
{button}
</Button>
{error && (
<Text as="p" mainUiMuted className="text-status-error-05 mt-2">
{error}
</Text>
)}
</>
);
}

View File

@@ -27,8 +27,10 @@ import { useUser } from "@/providers/UserProvider";
import { Formik, useFormikContext } from "formik";
import { useAgent } from "@/hooks/useAgents";
import { Button, MessageCard } from "@opal/components";
import { Disabled } from "@opal/core";
import { useLabels } from "@/lib/hooks";
import { PersonaLabel } from "@/app/admin/agents/interfaces";
import { FetchError } from "@/lib/fetcher";
const YOUR_ORGANIZATION_TAB = "Your Organization";
const USERS_AND_GROUPS_TAB = "Users & Groups";
@@ -56,8 +58,12 @@ interface ShareAgentFormContentProps {
function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
const { values, setFieldValue, handleSubmit, dirty, isSubmitting } =
useFormikContext<ShareAgentFormValues>();
const { data: usersData } = useShareableUsers({ includeApiKeys: true });
const { data: usersData, error: usersError } = useShareableUsers({
includeApiKeys: true,
});
const { data: groupsData } = useShareableGroups();
const userDirectoryRestricted =
usersError instanceof FetchError && usersError.status === 403;
const { user: currentUser, isAdmin, isCurator } = useUser();
const { agent: fullAgent } = useAgent(agentId ?? null);
const shareAgentModal = useModal();
@@ -70,12 +76,14 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
// Create options for InputComboBox from all accepted users and groups
const comboBoxOptions = useMemo(() => {
const userOptions = acceptedUsers
.filter((user) => user.id !== currentUser?.id)
.map((user) => ({
value: `user-${user.id}`,
label: user.email,
}));
const userOptions = userDirectoryRestricted
? []
: acceptedUsers
.filter((user) => user.id !== currentUser?.id)
.map((user) => ({
value: `user-${user.id}`,
label: user.email,
}));
const groupOptions = groups.map((group) => ({
value: `group-${group.id}`,
@@ -83,7 +91,10 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
}));
return [...userOptions, ...groupOptions];
}, [acceptedUsers, groups, currentUser?.id]);
}, [acceptedUsers, groups, currentUser?.id, userDirectoryRestricted]);
const comboBoxDisabled =
userDirectoryRestricted && comboBoxOptions.length === 0;
// Compute owner and displayed users
const ownerId = fullAgent?.owner?.id;
@@ -214,14 +225,31 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
<Tabs.Content value={USERS_AND_GROUPS_TAB}>
<Section gap={0.5} alignItems="start">
<InputComboBox
placeholder="Add users and groups"
value=""
onChange={() => {}}
onValueChange={handleComboBoxSelect}
options={comboBoxOptions}
strict
/>
<Disabled
disabled={comboBoxDisabled}
tooltip={
comboBoxDisabled
? "Your administrator has restricted the user directory. Contact an admin to share this agent with other users."
: undefined
}
tooltipSide="bottom"
>
<div className="w-full">
<InputComboBox
placeholder={
userDirectoryRestricted
? "Add groups"
: "Add users and groups"
}
value=""
onChange={() => {}}
onValueChange={handleComboBoxSelect}
options={comboBoxOptions}
strict
disabled={comboBoxDisabled}
/>
</div>
</Disabled>
{(displayedUsers.length > 0 || displayedGroups.length > 0) && (
<Section gap={0} alignItems="stretch">
{/* Shared Users */}