Compare commits

..

6 Commits

Author SHA1 Message Date
Nik
0cf3f22b3b fix(be): change CREDENTIAL_EXPIRED status from 400 to 401
An expired credential is an auth failure, not a malformed request.
401 aligns with UNAUTHENTICATED/INVALID_TOKEN/TOKEN_EXPIRED and
RFC 9110 semantics.
2026-03-03 13:58:41 -08:00
Nik
77fd9bb052 fix(be): wrap BillingServiceError in structured detail format
BillingServiceError catch blocks now return the same
{"error_code": "BILLING_SERVICE_ERROR", "message": "..."} dict shape
as all other error responses, while preserving the dynamic upstream
status code. Also documents this pattern in AGENTS.md.
2026-03-03 13:38:11 -08:00
Nik
42f723d17c fix(test): add isinstance type narrowing for mypy compatibility
HTTPException.detail is typed as str by starlette, but we pass dicts.
Add isinstance assertions before dict access to satisfy mypy.
2026-03-03 13:14:20 -08:00
Nik
d261f24578 fix(be): prevent exception masking and info leakage in billing_api
- Add `except HTTPException: raise` in create_subscription_session to
  prevent the broad `except Exception` from swallowing the 400
  VALIDATION_ERROR and re-raising it as 500 INTERNAL_ERROR.
- Replace `str(e)` in error details with generic messages to avoid
  leaking internal exception details to API consumers.
2026-03-03 12:19:53 -08:00
Nik
b670ac69e8 fix(be): update billing test assertions for structured error details
Tests now check detail["message"] instead of plain string detail
since OnyxErrorCode.detail() returns a dict.
2026-03-03 11:11:01 -08:00
Nik
31d95908c4 refactor(be): add OnyxErrorCode enum and migrate billing/license routers
Introduces a centralized OnyxErrorCode enum in
backend/onyx/error_handling/error_codes.py with standardized error codes
for the entire backend. Migrates billing and license API routers as the
initial adoption. Updates AGENTS.md to mandate usage going forward.
2026-03-03 10:30:39 -08:00
36 changed files with 613 additions and 522 deletions

View File

@@ -15,7 +15,6 @@ permissions:
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
secrets: inherit
permissions:
contents: read
id-token: write

View File

@@ -617,6 +617,46 @@ Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Error Handling
**Always use `OnyxErrorCode` from `onyx.error_handling.error_codes` when raising `HTTPException`. Never hardcode
status codes or use `starlette.status` / `fastapi.status` constants directly.**
**Reason:** Standardized error codes give API consumers a stable, machine-readable `error_code` field to match on,
and keep HTTP status codes consistent across the entire backend.
```python
from fastapi import HTTPException
from onyx.error_handling.error_codes import OnyxErrorCode
# ✅ Good
raise HTTPException(
status_code=OnyxErrorCode.NOT_FOUND.status_code,
detail=OnyxErrorCode.NOT_FOUND.detail("Session not found"),
)
# ✅ Good — no extra message needed
raise HTTPException(
status_code=OnyxErrorCode.UNAUTHENTICATED.status_code,
detail=OnyxErrorCode.UNAUTHENTICATED.detail(),
)
# ❌ Bad — hardcoded status code
raise HTTPException(status_code=404, detail="Session not found")
# ❌ Bad — starlette constant
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
```
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
category is needed, add it there first — do not invent ad-hoc codes.
**Exception — upstream service errors:** When forwarding errors from an upstream service (e.g.
`BillingServiceError`), the HTTP status code is dynamic and comes from the upstream response.
In these cases, use `error_code: "BILLING_SERVICE_ERROR"` (or similar) with the upstream message
wrapped in the standard `{"error_code": "...", "message": "..."}` dict shape. Do not map these
to `OnyxErrorCode` members since the status code is not fixed.
## Best Practices
In addition to the other content in this file, best practices for contributing

View File

@@ -58,6 +58,7 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -170,9 +171,11 @@ async def create_checkout_session(
used_seats = get_used_seats(tenant_id)
if seats < used_seats:
raise HTTPException(
status_code=400,
detail=f"Cannot subscribe with fewer seats than current usage. "
f"You have {used_seats} active users/integrations but requested {seats} seats.",
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
f"Cannot subscribe with fewer seats than current usage. "
f"You have {used_seats} active users/integrations but requested {seats} seats."
),
)
# Build redirect URL for after checkout completion
@@ -188,7 +191,11 @@ async def create_checkout_session(
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
# Preserve upstream status code; wrap message in structured format
raise HTTPException(
status_code=e.status_code,
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
)
@router.post("/create-customer-portal-session")
@@ -206,7 +213,10 @@ async def create_customer_portal_session(
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise HTTPException(status_code=400, detail="No license found")
raise HTTPException(
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail("No license found"),
)
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
@@ -217,7 +227,11 @@ async def create_customer_portal_session(
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
# Preserve upstream status code; wrap message in structured format
raise HTTPException(
status_code=e.status_code,
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
)
@router.get("/billing-information")
@@ -241,8 +255,10 @@ async def get_billing_information(
# Check circuit breaker (self-hosted only)
if _is_billing_circuit_open():
raise HTTPException(
status_code=503,
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
status_code=OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
detail=OnyxErrorCode.SERVICE_UNAVAILABLE.detail(
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry."
),
)
try:
@@ -254,7 +270,11 @@ async def get_billing_information(
# Open circuit breaker on connection failures (self-hosted only)
if e.status_code in (502, 503, 504):
_open_billing_circuit()
raise HTTPException(status_code=e.status_code, detail=e.message)
# Preserve upstream status code; wrap message in structured format
raise HTTPException(
status_code=e.status_code,
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
)
@router.post("/seats/update")
@@ -274,15 +294,20 @@ async def update_seats(
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise HTTPException(status_code=400, detail="No license found")
raise HTTPException(
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail("No license found"),
)
# Validate that new seat count is not less than current used seats
used_seats = get_used_seats(tenant_id)
if request.new_seat_count < used_seats:
raise HTTPException(
status_code=400,
detail=f"Cannot reduce seats below current usage. "
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
f"Cannot reduce seats below current usage. "
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats."
),
)
try:
@@ -298,7 +323,11 @@ async def update_seats(
return result
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
# Preserve upstream status code; wrap message in structured format
raise HTTPException(
status_code=e.status_code,
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
)
@router.get("/stripe-publishable-key")
@@ -330,8 +359,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Invalid Stripe publishable key format"
),
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
@@ -339,8 +370,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Stripe publishable key is not configured"
),
)
try:
@@ -352,16 +385,20 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Invalid Stripe publishable key format"
),
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Failed to fetch Stripe publishable key"
),
)

View File

@@ -35,6 +35,7 @@ from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -128,8 +129,10 @@ async def claim_license(
"""
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License claiming is only available for self-hosted deployments",
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
"License claiming is only available for self-hosted deployments"
),
)
try:
@@ -147,14 +150,19 @@ async def claim_license(
metadata = get_license_metadata(db_session)
if not metadata or not metadata.tenant_id:
raise HTTPException(
status_code=400,
detail="No license found. Provide session_id after checkout.",
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
"No license found. Provide session_id after checkout."
),
)
license_row = get_license(db_session)
if not license_row or not license_row.license_data:
raise HTTPException(
status_code=400, detail="No license found in database"
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
"No license found in database"
),
)
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
@@ -173,7 +181,10 @@ async def claim_license(
license_data = data.get("license")
if not license_data:
raise HTTPException(status_code=404, detail="No license in response")
raise HTTPException(
status_code=OnyxErrorCode.NOT_FOUND.status_code,
detail=OnyxErrorCode.NOT_FOUND.detail("No license in response"),
)
# Verify signature before persisting
payload = verify_license_signature(license_data)
@@ -201,10 +212,16 @@ async def claim_license(
pass
raise HTTPException(status_code=status_code, detail=detail)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(str(e)),
)
except requests.RequestException:
raise HTTPException(
status_code=502, detail="Failed to connect to license server"
status_code=OnyxErrorCode.BAD_GATEWAY.status_code,
detail=OnyxErrorCode.BAD_GATEWAY.detail(
"Failed to connect to license server"
),
)
@@ -222,8 +239,10 @@ async def upload_license(
"""
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License upload is only available for self-hosted deployments",
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
"License upload is only available for self-hosted deployments"
),
)
try:
@@ -234,14 +253,20 @@ async def upload_license(
# Remove any stray whitespace/newlines from user input
license_data = license_data.strip()
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="Invalid license file format")
raise HTTPException(
status_code=OnyxErrorCode.INVALID_INPUT.status_code,
detail=OnyxErrorCode.INVALID_INPUT.detail("Invalid license file format"),
)
# Verify cryptographic signature - this is the only validation needed
# The license's tenant_id identifies the customer in control plane, not locally
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(str(e)),
)
# Persist to DB and update cache
upsert_license(db_session, license_data)
@@ -298,8 +323,10 @@ async def delete_license(
"""
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License deletion is only available for self-hosted deployments",
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
"License deletion is only available for self-hosted deployments"
),
)
try:

View File

@@ -43,6 +43,7 @@ from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -116,9 +117,14 @@ async def create_customer_portal_session(
try:
portal_url = fetch_customer_portal_session(tenant_id, return_url)
return {"stripe_customer_portal_url": portal_url}
except Exception as e:
except Exception:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Failed to create customer portal session"
),
)
@router.post("/create-checkout-session")
@@ -134,9 +140,14 @@ async def create_checkout_session(
try:
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
return {"stripe_checkout_url": checkout_url}
except Exception as e:
except Exception:
logger.exception("Failed to create checkout session")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Failed to create checkout session"
),
)
@router.post("/create-subscription-session")
@@ -147,15 +158,25 @@ async def create_subscription_session(
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
raise HTTPException(
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
detail=OnyxErrorCode.VALIDATION_ERROR.detail("Tenant ID not found"),
)
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
except HTTPException:
raise
except Exception:
logger.exception("Failed to create subscription session")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Failed to create subscription session"
),
)
@router.get("/stripe-publishable-key")
@@ -187,8 +208,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Invalid Stripe publishable key format"
),
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
@@ -196,8 +219,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Stripe publishable key is not configured"
),
)
try:
@@ -209,14 +234,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Invalid Stripe publishable key format"
),
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
"Failed to fetch Stripe publishable key"
),
)

View File

@@ -819,9 +819,7 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
# Tool Configs
#####
# Code Interpreter Service Configuration
CODE_INTERPRETER_BASE_URL = os.environ.get(
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
)
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000

View File

@@ -532,7 +532,6 @@ def fetch_default_model(
) -> ModelConfiguration | None:
model_config = db_session.scalar(
select(ModelConfiguration)
.options(selectinload(ModelConfiguration.llm_provider))
.join(LLMModelFlow)
.where(
ModelConfiguration.is_visible == True, # noqa: E712

View File

@@ -52,7 +52,7 @@ def create_user_files(
) -> CategorizedFilesResult:
# Categorize the files
categorized_files = categorize_uploaded_files(files, db_session)
categorized_files = categorize_uploaded_files(files)
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)

View File

View File

@@ -0,0 +1,103 @@
"""
Standardized error codes for the Onyx backend.
Usage:
from onyx.error_handling.error_codes import OnyxErrorCode
raise HTTPException(
status_code=OnyxErrorCode.UNAUTHENTICATED.status_code,
detail=OnyxErrorCode.UNAUTHENTICATED.detail("Token expired"),
)
"""
from enum import Enum
class OnyxErrorCode(Enum):
"""
Each member is a tuple of (error_code_string, http_status_code).
The error_code_string is a stable, machine-readable identifier that
API consumers can match on. The http_status_code is the default HTTP
status to return.
"""
# ------------------------------------------------------------------
# Authentication (401)
# ------------------------------------------------------------------
UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
INVALID_TOKEN = ("INVALID_TOKEN", 401)
TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
CSRF_FAILURE = ("CSRF_FAILURE", 403)
# ------------------------------------------------------------------
# Authorization (403)
# ------------------------------------------------------------------
UNAUTHORIZED = ("UNAUTHORIZED", 403)
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
ADMIN_ONLY = ("ADMIN_ONLY", 403)
EE_REQUIRED = ("EE_REQUIRED", 403)
# ------------------------------------------------------------------
# Validation / Bad Request (400)
# ------------------------------------------------------------------
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
# ------------------------------------------------------------------
# Not Found (404)
# ------------------------------------------------------------------
NOT_FOUND = ("NOT_FOUND", 404)
CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
# ------------------------------------------------------------------
# Conflict (409)
# ------------------------------------------------------------------
CONFLICT = ("CONFLICT", 409)
DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
# ------------------------------------------------------------------
# Rate Limiting / Quotas (429 / 402)
# ------------------------------------------------------------------
RATE_LIMITED = ("RATE_LIMITED", 429)
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
# ------------------------------------------------------------------
# Connector / Credential Errors (400-range)
# ------------------------------------------------------------------
CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
# ------------------------------------------------------------------
# Server Errors (5xx)
# ------------------------------------------------------------------
INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:
self.code = code
self.status_code = status_code
def detail(self, message: str | None = None) -> dict[str, str]:
"""Build a structured error detail dict.
Returns a dict like:
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
If no message is supplied, the error code itself is used as the message.
"""
return {
"error_code": self.code,
"message": message or self.code,
}

View File

@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
pass
def categorize_uploaded_files(
files: list[UploadFile], db_session: Session
) -> CategorizedFiles:
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
"""
Categorize uploaded files based on text extractability and tokenized length.
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
"""
results = CategorizedFiles()
default_model = fetch_default_llm_model(db_session)
llm = get_default_llm()
model_name = default_model.name if default_model else None
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
tokenizer = get_tokenizer(
model_name=llm.config.model_name, provider_type=llm.config.model_provider
)
# Check if threshold checks should be skipped
skip_threshold = False

View File

@@ -1,5 +1,4 @@
import json
import time
from collections.abc import Generator
from typing import Literal
from typing import TypedDict
@@ -13,9 +12,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_HEALTH_CACHE_TTL_SECONDS = 30
_health_cache: dict[str, tuple[float, bool]] = {}
class FileInput(TypedDict):
"""Input file to be staged in execution workspace"""
@@ -102,32 +98,16 @@ class CodeInterpreterClient:
payload["files"] = files
return payload
def health(self, use_cache: bool = False) -> bool:
"""Check if the Code Interpreter service is healthy
Args:
use_cache: When True, return a cached result if available and
within the TTL window. The cache is always populated
after a live request regardless of this flag.
"""
if use_cache:
cached = _health_cache.get(self.base_url)
if cached is not None:
cached_at, cached_result = cached
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
return cached_result
def health(self) -> bool:
"""Check if the Code Interpreter service is healthy"""
url = f"{self.base_url}/health"
try:
response = self.session.get(url, timeout=5)
response.raise_for_status()
result = response.json().get("status") == "ok"
return response.json().get("status") == "ok"
except Exception as e:
logger.warning(f"Exception caught when checking health, e={e}")
result = False
_health_cache[self.base_url] = (time.monotonic(), result)
return result
return False
def execute(
self,

View File

@@ -107,11 +107,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
if not CODE_INTERPRETER_BASE_URL:
return False
server = fetch_code_interpreter_server(db_session)
if not server.server_enabled:
return False
client = CodeInterpreterClient()
return client.health(use_cache=True)
return server.server_enabled
def tool_definition(self) -> dict:
return {

View File

@@ -1027,13 +1027,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
else:
self._respond_json(404, {"error": "not found"})
def do_GET(self) -> None:
self._capture("GET", b"")
if self.path == "/health":
self._respond_json(200, {"status": "ok"})
else:
self._respond_json(404, {"error": "not found"})
def do_DELETE(self) -> None:
self._capture("DELETE", b"")
self.send_response(200)
@@ -1114,14 +1107,6 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]:
server.shutdown()
@pytest.fixture(autouse=True)
def _clear_health_cache() -> None:
"""Reset the health check cache before every test."""
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
mod._health_cache = {}
@pytest.fixture()
def _attach_python_tool_to_default_persona(db_session: Session) -> None:
"""Ensure the default persona (id=0) has the PythonTool attached."""

View File

@@ -103,7 +103,9 @@ class TestCreateCheckoutSession:
)
assert exc_info.value.status_code == 502
assert "Stripe error" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert exc_info.value.detail["error_code"] == "BILLING_SERVICE_ERROR"
assert exc_info.value.detail["message"] == "Stripe error"
class TestCreateCustomerPortalSession:
@@ -134,7 +136,8 @@ class TestCreateCustomerPortalSession:
)
assert exc_info.value.status_code == 400
assert "No license found" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert exc_info.value.detail["message"] == "No license found"
@pytest.mark.asyncio
@patch("ee.onyx.server.billing.api.create_portal_service")
@@ -241,7 +244,8 @@ class TestUpdateSeats:
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
assert exc_info.value.status_code == 400
assert "No license found" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert exc_info.value.detail["message"] == "No license found"
@pytest.mark.asyncio
@patch("ee.onyx.server.billing.api.get_used_seats")
@@ -314,7 +318,9 @@ class TestUpdateSeats:
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
assert exc_info.value.status_code == 400
assert "Cannot reduce below 10 seats" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert exc_info.value.detail["error_code"] == "BILLING_SERVICE_ERROR"
assert exc_info.value.detail["message"] == "Cannot reduce below 10 seats"
class TestCircuitBreaker:
@@ -344,7 +350,8 @@ class TestCircuitBreaker:
await get_billing_information(_=MagicMock(), db_session=MagicMock())
assert exc_info.value.status_code == 503
assert "Connect to Stripe" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert "Connect to Stripe" in exc_info.value.detail["message"]
@pytest.mark.asyncio
@patch("ee.onyx.server.billing.api.MULTI_TENANT", False)

View File

@@ -70,7 +70,10 @@ class TestGetStripePublishableKey:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "Invalid Stripe publishable key format" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert (
exc_info.value.detail["message"] == "Invalid Stripe publishable key format"
)
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@@ -96,7 +99,10 @@ class TestGetStripePublishableKey:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "Invalid Stripe publishable key format" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert (
exc_info.value.detail["message"] == "Invalid Stripe publishable key format"
)
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@@ -118,7 +124,10 @@ class TestGetStripePublishableKey:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "Failed to fetch Stripe publishable key" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert (
exc_info.value.detail["message"] == "Failed to fetch Stripe publishable key"
)
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@@ -133,7 +142,8 @@ class TestGetStripePublishableKey:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "not configured" in exc_info.value.detail
assert isinstance(exc_info.value.detail, dict)
assert "not configured" in exc_info.value.detail["message"]
@pytest.mark.asyncio
@patch(

View File

@@ -1,37 +1,25 @@
"""Tests for PythonTool availability based on server_enabled flag and health check.
"""Tests for PythonTool availability based on server_enabled flag.
Verifies that PythonTool reports itself as unavailable when either:
- CODE_INTERPRETER_BASE_URL is not set, or
- CodeInterpreterServer.server_enabled is False in the database, or
- The Code Interpreter service health check fails.
Also verifies that the health check result is cached with a TTL.
- CodeInterpreterServer.server_enabled is False in the database.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
CLIENT_MODULE = "onyx.tools.tool_implementations.python.code_interpreter_client"
@pytest.fixture(autouse=True)
def _clear_health_cache() -> None:
"""Reset the health check cache before every test."""
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
mod._health_cache = {}
# ------------------------------------------------------------------
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
# ------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", None)
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
None,
)
def test_python_tool_unavailable_without_base_url() -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
@@ -39,7 +27,10 @@ def test_python_tool_unavailable_without_base_url() -> None:
assert PythonTool.is_available(db_session) is False
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "")
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"",
)
def test_python_tool_unavailable_with_empty_base_url() -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
@@ -52,8 +43,13 @@ def test_python_tool_unavailable_with_empty_base_url() -> None:
# ------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://localhost:8000",
)
@patch(
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
)
def test_python_tool_unavailable_when_server_disabled(
mock_fetch: MagicMock,
) -> None:
@@ -68,15 +64,18 @@ def test_python_tool_unavailable_when_server_disabled(
# ------------------------------------------------------------------
# Health check determines availability when URL + server are OK
# Available when both conditions are met
# ------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
def test_python_tool_available_when_health_check_passes(
mock_client_cls: MagicMock,
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://localhost:8000",
)
@patch(
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
)
def test_python_tool_available_when_server_enabled(
mock_fetch: MagicMock,
) -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
@@ -85,120 +84,5 @@ def test_python_tool_available_when_health_check_passes(
mock_server.server_enabled = True
mock_fetch.return_value = mock_server
mock_client = MagicMock()
mock_client.health.return_value = True
mock_client_cls.return_value = mock_client
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is True
mock_client.health.assert_called_once_with(use_cache=True)
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
def test_python_tool_unavailable_when_health_check_fails(
mock_client_cls: MagicMock,
mock_fetch: MagicMock,
) -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
mock_server = MagicMock()
mock_server.server_enabled = True
mock_fetch.return_value = mock_server
mock_client = MagicMock()
mock_client.health.return_value = False
mock_client_cls.return_value = mock_client
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False
mock_client.health.assert_called_once_with(use_cache=True)
# ------------------------------------------------------------------
# Health check is NOT reached when preconditions fail
# ------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
def test_health_check_not_called_when_server_disabled(
mock_client_cls: MagicMock,
mock_fetch: MagicMock,
) -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
mock_server = MagicMock()
mock_server.server_enabled = False
mock_fetch.return_value = mock_server
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False
mock_client_cls.assert_not_called()
# ------------------------------------------------------------------
# Health check caching (tested at the client level)
# ------------------------------------------------------------------
def test_health_check_cached_on_second_call() -> None:
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
client = CodeInterpreterClient(base_url="http://fake:9000")
mock_response = MagicMock()
mock_response.json.return_value = {"status": "ok"}
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
assert client.health(use_cache=True) is True
assert client.health(use_cache=True) is True
# Only one HTTP call — the second used the cache
mock_get.assert_called_once()
@patch(f"{CLIENT_MODULE}.time")
def test_health_check_refreshed_after_ttl_expires(mock_time: MagicMock) -> None:
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
_HEALTH_CACHE_TTL_SECONDS,
)
client = CodeInterpreterClient(base_url="http://fake:9000")
mock_response = MagicMock()
mock_response.json.return_value = {"status": "ok"}
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
# First call at t=0 — cache miss
mock_time.monotonic.return_value = 0.0
assert client.health(use_cache=True) is True
assert mock_get.call_count == 1
# Second call within TTL — cache hit
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS - 1)
assert client.health(use_cache=True) is True
assert mock_get.call_count == 1
# Third call after TTL — cache miss, fresh request
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS + 1)
assert client.health(use_cache=True) is True
assert mock_get.call_count == 2
def test_health_check_no_cache_by_default() -> None:
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
client = CodeInterpreterClient(base_url="http://fake:9000")
mock_response = MagicMock()
mock_response.json.return_value = {"status": "ok"}
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
assert client.health() is True
assert client.health() is True
# Both calls hit the network when use_cache=False (default)
assert mock_get.call_count == 2

View File

@@ -1,72 +0,0 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 2,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "greptile.json\n",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"customContext": {
"other": [
{
"scope": [],
"content": ""
}
],
"rules": [
{
"scope": [],
"rule": ""
}
],
"files": [
{
"scope": [],
"path": "",
"description": ""
}
]
}
}

View File

@@ -23,7 +23,6 @@
"test:verbose": "jest --verbose",
"test:ci": "jest --ci --maxWorkers=2 --silent --bail",
"test:changed": "jest --onlyChanged",
"test:diff": "jest --changedSince=main",
"test:debug": "node --inspect-brk node_modules/.bin/jest --runInBand"
},
"dependencies": {

View File

@@ -187,7 +187,6 @@ export const fetchOllamaModels = async (
api_base: apiBase,
provider_name: params.provider_name,
}),
signal: params.signal,
});
if (!response.ok) {

View File

@@ -20,7 +20,6 @@ import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/
import { Credential } from "@/lib/connectors/credentials";
import { useFederatedConnectors } from "@/lib/hooks";
import Text from "@/refresh-components/texts/Text";
import { useToastFromQuery } from "@/hooks/useToast";
export default function ConnectorWrapper({
connector,
@@ -30,13 +29,6 @@ export default function ConnectorWrapper({
const searchParams = useSearchParams();
const mode = searchParams?.get("mode"); // 'federated' or 'regular'
useToastFromQuery({
oauth_failed: {
message: "OAuth authentication failed. Please try again.",
type: "error",
},
});
// Check if the connector is valid
if (!isValidSource(connector)) {
return (

View File

@@ -2,6 +2,10 @@ import { getDomain } from "@/lib/redirectSS";
import { buildUrl } from "@/lib/utilsSS";
import { NextRequest, NextResponse } from "next/server";
import { cookies } from "next/headers";
import {
GMAIL_AUTH_IS_ADMIN_COOKIE_NAME,
GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME,
} from "@/lib/constants";
import {
CRAFT_OAUTH_COOKIE_NAME,
CRAFT_CONFIGURE_PATH,
@@ -11,7 +15,6 @@ import { processCookies } from "@/lib/userSS";
export const GET = async (request: NextRequest) => {
const requestCookies = await cookies();
const connector = request.url.includes("gmail") ? "gmail" : "google-drive";
const callbackEndpoint = `/manage/connector/${connector}/callback`;
const url = new URL(buildUrl(callbackEndpoint));
url.search = request.nextUrl.search;
@@ -23,12 +26,7 @@ export const GET = async (request: NextRequest) => {
});
if (!response.ok) {
return NextResponse.redirect(
new URL(
`/admin/connectors/${connector}?message=oauth_failed`,
getDomain(request)
)
);
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
}
// Check for build mode OAuth flag (redirects to build admin panel)
@@ -42,7 +40,16 @@ export const GET = async (request: NextRequest) => {
return redirectResponse;
}
return NextResponse.redirect(
new URL(`/admin/connectors/${connector}`, getDomain(request))
);
const authCookieName =
connector === "gmail"
? GMAIL_AUTH_IS_ADMIN_COOKIE_NAME
: GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME;
if (requestCookies.get(authCookieName)?.value?.toLowerCase() === "true") {
return NextResponse.redirect(
new URL(`/admin/connectors/${connector}`, getDomain(request))
);
}
return NextResponse.redirect(new URL("/user/connectors", getDomain(request)));
};

View File

@@ -1,7 +1,7 @@
"use client";
import { useEffect, useState } from "react";
import { usePathname, useSearchParams } from "next/navigation";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { AdminPageTitle } from "@/components/admin/Title";
import { getSourceMetadata, isValidSource } from "@/lib/sources";
import { ValidSources } from "@/lib/types";
@@ -9,6 +9,7 @@ import CardSection from "@/components/admin/CardSection";
import { handleOAuthAuthorizationResponse } from "@/lib/oauth_utils";
import { SvgKey } from "@opal/icons";
export default function OAuthCallbackPage() {
const router = useRouter();
const searchParams = useSearchParams();
const [statusMessage, setStatusMessage] = useState("Processing...");

View File

@@ -6,7 +6,11 @@ import { useRouter } from "next/navigation";
import type { Route } from "next";
import { adminDeleteCredential } from "@/lib/credential";
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
import { DOCS_ADMINS_PATH } from "@/lib/constants";
import {
DOCS_ADMINS_PATH,
GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME,
} from "@/lib/constants";
import Cookies from "js-cookie";
import { TextFormField, SectionHeader } from "@/components/Field";
import { Form, Formik } from "formik";
import { User } from "@/lib/types";
@@ -588,6 +592,11 @@ export const DriveAuthSection = ({
onClick={async () => {
setIsAuthenticating(true);
try {
// cookie used by callback to determine where to finally redirect to
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
path: "/",
});
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",

View File

@@ -7,7 +7,10 @@ import { useRouter } from "next/navigation";
import type { Route } from "next";
import { adminDeleteCredential } from "@/lib/credential";
import { setupGmailOAuth } from "@/lib/gmail";
import { DOCS_ADMINS_PATH } from "@/lib/constants";
import {
DOCS_ADMINS_PATH,
GMAIL_AUTH_IS_ADMIN_COOKIE_NAME,
} from "@/lib/constants";
import { CRAFT_OAUTH_COOKIE_NAME } from "@/app/craft/v1/constants";
import Cookies from "js-cookie";
import { TextFormField, SectionHeader } from "@/components/Field";
@@ -599,6 +602,9 @@ export const GmailAuthSection = ({
onClick={async () => {
setIsAuthenticating(true);
try {
Cookies.set(GMAIL_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
path: "/",
});
if (buildMode) {
Cookies.set(CRAFT_OAUTH_COOKIE_NAME, "true", {
path: "/",

View File

@@ -0,0 +1,76 @@
import { SubLabel } from "@/components/Field";
import { toast } from "@/hooks/useToast";
import { useEffect, useState } from "react";
import Dropzone from "react-dropzone";
export function ImageUpload({
selectedFile,
setSelectedFile,
}: {
selectedFile: File | null;
setSelectedFile: (file: File) => void;
}) {
const [tmpImageUrl, setTmpImageUrl] = useState<string>("");
useEffect(() => {
if (selectedFile) {
setTmpImageUrl(URL.createObjectURL(selectedFile));
} else {
setTmpImageUrl("");
}
}, [selectedFile]);
const [dragActive, setDragActive] = useState(false);
return (
<Dropzone
onDrop={(acceptedFiles) => {
if (acceptedFiles.length !== 1) {
toast.error("Only one file can be uploaded at a time");
return;
}
const acceptedFile = acceptedFiles[0];
if (acceptedFile === undefined) {
toast.error("acceptedFile cannot be undefined");
return;
}
setTmpImageUrl(URL.createObjectURL(acceptedFile));
setSelectedFile(acceptedFile);
setDragActive(false);
}}
onDragLeave={() => setDragActive(false)}
onDragEnter={() => setDragActive(true)}
>
{({ getRootProps, getInputProps }) => (
<section>
<div
{...getRootProps()}
className={
"flex flex-col items-center w-full px-4 py-12 rounded " +
"shadow-lg tracking-wide border border-border cursor-pointer" +
(dragActive ? " border-accent" : "")
}
>
<input {...getInputProps()} />
<b className="text-text-darker">
Drag and drop a .png or .jpg file, or click to select a file!
</b>
</div>
{tmpImageUrl && (
<div className="mt-4 mb-8">
<SubLabel>Uploaded Image:</SubLabel>
<img
alt="Uploaded Image"
src={tmpImageUrl}
className="mt-4 max-w-xs max-h-64"
/>
</div>
)}
</section>
)}
</Dropzone>
);
}

View File

@@ -16,6 +16,8 @@ import { cn } from "@/lib/utils";
const CsvContent: React.FC<ContentComponentProps> = ({
fileDescriptor,
isLoading,
fadeIn,
expanded = false,
}) => {
const [data, setData] = useState<Record<string, string>[]>([]);
@@ -92,7 +94,7 @@ const CsvContent: React.FC<ContentComponentProps> = ({
}
};
if (isFetching) {
if (isLoading || isFetching) {
return (
<div className="flex items-center justify-center h-[300px]">
<SimpleLoader />

View File

@@ -1,5 +1,5 @@
// ExpandableContentWrapper
import React, { useState } from "react";
import React, { useState, useEffect } from "react";
import { SvgDownloadCloud, SvgFold, SvgMaximize2, SvgX } from "@opal/icons";
import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card";
import { Button } from "@opal/components";
@@ -17,6 +17,8 @@ export interface ExpandableContentWrapperProps {
export interface ContentComponentProps {
fileDescriptor: FileDescriptor;
isLoading: boolean;
fadeIn: boolean;
expanded?: boolean;
}
@@ -26,9 +28,24 @@ export default function ExpandableContentWrapper({
ContentComponent,
}: ExpandableContentWrapperProps) {
const [expanded, setExpanded] = useState(false);
const [isLoading, setIsLoading] = useState(true);
const [fadeIn, setFadeIn] = useState(false);
const toggleExpand = () => setExpanded((prev) => !prev);
// Prevent a jarring fade in
useEffect(() => {
setTimeout(() => setIsLoading(false), 300);
}, []);
useEffect(() => {
if (!isLoading) {
setTimeout(() => setFadeIn(true), 50);
} else {
setFadeIn(false);
}
}, [isLoading]);
const downloadFile = () => {
const a = document.createElement("a");
a.href = `api/chat/file/${fileDescriptor.id}`;
@@ -86,6 +103,8 @@ export default function ExpandableContentWrapper({
{!expanded && (
<ContentComponent
fileDescriptor={fileDescriptor}
isLoading={isLoading}
fadeIn={fadeIn}
expanded={expanded}
/>
)}

View File

@@ -118,7 +118,6 @@ export interface BedrockFetchParams {
export interface OllamaFetchParams {
api_base?: string;
provider_name?: string;
signal?: AbortSignal;
}
export interface OpenRouterFetchParams {

View File

@@ -28,6 +28,11 @@ export const NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED =
export const TENANT_ID_COOKIE_NAME = "onyx_tid";
export const GMAIL_AUTH_IS_ADMIN_COOKIE_NAME = "gmail_auth_is_admin";
export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME =
"google_drive_auth_is_admin";
export const SEARCH_TYPE_COOKIE_NAME = "search_type";
export const AGENTIC_SEARCH_TYPE_COOKIE_NAME = "agentic_type";

View File

@@ -9,7 +9,6 @@ import { Button as OpalButton } from "@opal/components";
import InputAvatar from "@/refresh-components/inputs/InputAvatar";
import { cn } from "@/lib/utils";
import { SvgCheckCircle, SvgEdit, SvgUser, SvgX } from "@opal/icons";
import { ContentAction } from "@opal/layouts";
export default function NonAdminStep() {
const inputRef = useRef<HTMLInputElement>(null);
@@ -55,26 +54,17 @@ export default function NonAdminStep() {
className="flex items-center justify-between w-full min-h-11 py-1 pl-3 pr-2 bg-background-tint-00 rounded-16 shadow-01 mb-2"
aria-label="non-admin-confirmation"
>
<ContentAction
icon={({ className, ...props }) => (
<SvgCheckCircle
className={cn(className, "stroke-status-success-05")}
{...props}
/>
)}
title="You're all set!"
sizePreset="main-ui"
variant="body"
prominence="muted"
paddingVariant="fit"
rightChildren={
<OpalButton
prominence="tertiary"
size="sm"
icon={SvgX}
onClick={handleDismissConfirmation}
/>
}
<div className="flex items-center gap-1">
<SvgCheckCircle className="w-4 h-4 stroke-status-success-05" />
<Text as="p" text03 mainUiBody>
You're all set!
</Text>
</div>
<OpalButton
prominence="tertiary"
size="sm"
icon={SvgX}
onClick={handleDismissConfirmation}
/>
</div>
)}
@@ -85,36 +75,39 @@ export default function NonAdminStep() {
role="group"
aria-label="non-admin-name-prompt"
>
<ContentAction
icon={SvgUser}
title="What should Onyx call you?"
description="We will display this name in the app."
sizePreset="main-ui"
variant="section"
paddingVariant="fit"
rightChildren={
<div className="flex items-center justify-end gap-2">
<InputTypeIn
ref={inputRef}
placeholder="Your name"
value={name || ""}
onChange={(e: React.ChangeEvent<HTMLInputElement>) =>
setName(e.target.value)
}
onKeyDown={(e) => {
if (e.key === "Enter" && name && name.trim().length > 0) {
e.preventDefault();
handleSave();
}
}}
className="w-[26%] min-w-40"
/>
<Button disabled={name === ""} onClick={handleSave}>
Save
</Button>
</div>
}
/>
<div className="flex items-center gap-1 h-full">
<div className="h-full p-0.5">
<SvgUser className="w-4 h-4 stroke-text-03" />
</div>
<div>
<Text as="p" text04 mainUiAction>
What should Onyx call you?
</Text>
<Text as="p" text03 secondaryBody>
We will display this name in the app.
</Text>
</div>
</div>
<div className="flex items-center justify-end gap-2">
<InputTypeIn
ref={inputRef}
placeholder="Your name"
value={name || ""}
onChange={(e: React.ChangeEvent<HTMLInputElement>) =>
setName(e.target.value)
}
onKeyDown={(e) => {
if (e.key === "Enter" && name && name.trim().length > 0) {
e.preventDefault();
handleSave();
}
}}
className="w-[26%] min-w-40"
/>
<Button disabled={name === ""} onClick={handleSave}>
Save
</Button>
</div>
</div>
) : (
<div

View File

@@ -1,5 +1,3 @@
"use client";
import { memo, useState, useCallback } from "react";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
@@ -14,7 +12,6 @@ import {
import { Disabled } from "@/refresh-components/Disabled";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import { SvgCheckCircle, SvgCpu, SvgExternalLink } from "@opal/icons";
import { ContentAction } from "@opal/layouts";
type LLMStepProps = {
state: OnboardingState;
@@ -125,14 +122,21 @@ const LLMStepInner = ({
className="flex flex-col items-center justify-between w-full p-1 rounded-16 border border-border-01 bg-background-tint-00"
aria-label="onboarding-llm-step"
>
<ContentAction
icon={SvgCpu}
title="Connect your LLM models"
description="Onyx supports both self-hosted models and popular providers."
sizePreset="main-ui"
variant="section"
paddingVariant="lg"
rightChildren={
<div className="flex gap-2 justify-between h-full w-full">
<div className="flex mx-2 mt-2 gap-1">
<div className="h-full p-0.5">
<SvgCpu className="w-4 h-4 stroke-text-03" />
</div>
<div>
<Text as="p" text04 mainUiAction>
Connect your LLM models
</Text>
<Text as="p" text03 secondaryBody>
Onyx supports both self-hosted models and popular providers.
</Text>
</div>
</div>
<div className="p-0.5">
<Button
tertiary
rightIcon={SvgExternalLink}
@@ -141,8 +145,8 @@ const LLMStepInner = ({
>
View in Admin Panel
</Button>
}
/>
</div>
</div>
<Separator />
<div className="flex flex-wrap gap-1 [&>*:last-child:nth-child(odd)]:basis-full">
{isLoading ? (

View File

@@ -8,7 +8,6 @@ import InputAvatar from "@/refresh-components/inputs/InputAvatar";
import { cn } from "@/lib/utils";
import IconButton from "@/refresh-components/buttons/IconButton";
import { SvgCheckCircle, SvgEdit, SvgUser } from "@opal/icons";
import { ContentAction } from "@opal/layouts";
export interface NameStepProps {
state: OnboardingState;
@@ -41,23 +40,26 @@ const NameStep = React.memo(
role="group"
aria-label="onboarding-name-step"
>
<ContentAction
icon={SvgUser}
title="What should Onyx call you?"
description="We will display this name in the app."
sizePreset="main-ui"
variant="section"
paddingVariant="fit"
rightChildren={
<InputTypeIn
ref={inputRef}
placeholder="Your name"
value={userName || ""}
onChange={(e) => updateName(e.target.value)}
onKeyDown={handleKeyDown}
className="max-w-60"
/>
}
<div className="flex items-center gap-1 h-full">
<div className="h-full p-0.5">
<SvgUser className="w-4 h-4 stroke-text-03" />
</div>
<div>
<Text as="p" text04 mainUiAction>
What should Onyx call you?
</Text>
<Text as="p" text03 secondaryBody>
We will display this name in the app.
</Text>
</div>
</div>
<InputTypeIn
ref={inputRef}
placeholder="Your name"
value={userName || ""}
onChange={(e) => updateName(e.target.value)}
onKeyDown={handleKeyDown}
className="max-w-60"
/>
</div>
) : (

View File

@@ -25,7 +25,7 @@ export interface ActionItemProps {
disabled: boolean;
isForced: boolean;
isUnavailable?: boolean;
tooltip?: string;
unavailableReason?: string;
showAdminConfigure?: boolean;
adminConfigureHref?: string;
adminConfigureTooltip?: string;
@@ -47,7 +47,7 @@ export default function ActionLineItem({
disabled,
isForced,
isUnavailable = false,
tooltip,
unavailableReason,
showAdminConfigure = false,
adminConfigureHref,
adminConfigureTooltip = "Configure",
@@ -88,7 +88,7 @@ export default function ActionLineItem({
sourceCounts.enabled > 0 &&
sourceCounts.enabled < sourceCounts.total;
const tooltipText = tooltip || tool?.description;
const tooltipText = isUnavailable ? unavailableReason : tool?.description;
return (
<SimpleTooltip tooltip={tooltipText} className="max-w-[30rem]">

View File

@@ -42,45 +42,26 @@ import { useProjectsContext } from "@/providers/ProjectsContext";
import { SvgActions, SvgChevronRight, SvgKey, SvgSliders } from "@opal/icons";
import { Button } from "@opal/components";
function buildTooltipMessage(
actionDescription: string,
isConfigured: boolean,
isAdmin: boolean
) {
const _ADMIN_CONFIGURE_MESSAGE = "Press the settings cog to enable.";
const _USER_NOT_ADMIN_MESSAGE = "Ask an admin to configure.";
if (isConfigured) {
return actionDescription;
}
if (isAdmin) {
return actionDescription + " " + _ADMIN_CONFIGURE_MESSAGE;
}
return actionDescription + " " + _USER_NOT_ADMIN_MESSAGE;
}
const TOOL_DESCRIPTIONS: Record<string, string> = {
[SEARCH_TOOL_ID]: "Search through connected knowledge to inform the answer.",
[IMAGE_GENERATION_TOOL_ID]: "Generate images based on a prompt.",
[WEB_SEARCH_TOOL_ID]: "Search the web for up-to-date information.",
[PYTHON_TOOL_ID]: "Execute code for complex analysis.",
const UNAVAILABLE_TOOL_TOOLTIP_FALLBACK =
"This action is not configured yet. Ask an admin to enable it.";
const UNAVAILABLE_TOOL_TOOLTIP_ADMIN_FALLBACK =
"This action is not configured yet. If you have access, enable it in the admin panel.";
const UNAVAILABLE_TOOL_TOOLTIPS: Record<string, string> = {
[IMAGE_GENERATION_TOOL_ID]:
"Image generation requires a configured model. If you have access, set one up under Settings > Image Generation, or ask an admin.",
[WEB_SEARCH_TOOL_ID]:
"Web search requires a configured provider. If you have access, set one up under Settings > Web Search, or ask an admin.",
[PYTHON_TOOL_ID]:
"Code Interpreter requires the service to be configured with a valid base URL. If you have access, configure it in the admin panel, or ask an admin.",
};
const DEFAULT_TOOL_DESCRIPTION = "This action is not configured yet.";
function getToolTooltip(
tool: ToolSnapshot,
isConfigured: boolean,
isAdmin: boolean
): string {
const description =
(tool.in_code_tool_id && TOOL_DESCRIPTIONS[tool.in_code_tool_id]) ||
tool.description ||
DEFAULT_TOOL_DESCRIPTION;
return buildTooltipMessage(description, isConfigured, isAdmin);
}
const getUnavailableToolTooltip = (
inCodeToolId?: string | null,
canAdminConfigure?: boolean
) =>
(inCodeToolId && UNAVAILABLE_TOOL_TOOLTIPS[inCodeToolId]) ??
(canAdminConfigure
? UNAVAILABLE_TOOL_TOOLTIP_ADMIN_FALLBACK
: UNAVAILABLE_TOOL_TOOLTIP_FALLBACK);
const ADMIN_CONFIG_LINKS: Record<string, { href: string; tooltip: string }> = {
[IMAGE_GENERATION_TOOL_ID]: {
@@ -91,10 +72,6 @@ const ADMIN_CONFIG_LINKS: Record<string, { href: string; tooltip: string }> = {
href: "/admin/configuration/web-search",
tooltip: "Configure Web Search",
},
[PYTHON_TOOL_ID]: {
href: "/admin/configuration/code-interpreter",
tooltip: "Configure Code Interpreter",
},
KnowledgeGraphTool: {
href: "/admin/kg",
tooltip: "Configure Knowledge Graph",
@@ -915,11 +892,14 @@ export default function ActionsPopover({
disabled={disabledToolIds.includes(tool.id)}
isForced={forcedToolIds.includes(tool.id)}
isUnavailable={isUnavailable}
tooltip={getToolTooltip(
tool,
isToolAvailable,
canAdminConfigure
)}
unavailableReason={
isUnavailable
? getUnavailableToolTooltip(
tool.in_code_tool_id,
canAdminConfigure
)
: undefined
}
showAdminConfigure={!!adminConfigureInfo}
adminConfigureHref={adminConfigureInfo?.href}
adminConfigureTooltip={adminConfigureInfo?.tooltip}

View File

@@ -23,9 +23,8 @@ import {
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useEffect, useState } from "react";
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
@@ -62,16 +61,14 @@ function OllamaModalContent({
}: OllamaModalContentProps) {
const [isLoadingModels, setIsLoadingModels] = useState(true);
const fetchModels = useCallback(
(apiBase: string, signal: AbortSignal) => {
useEffect(() => {
if (formikProps.values.api_base) {
setIsLoadingModels(true);
fetchOllamaModels({
api_base: apiBase,
api_base: formikProps.values.api_base,
provider_name: existingLlmProvider?.name,
signal,
})
.then((data) => {
if (signal.aborted) return;
if (data.error) {
console.error("Error fetching models:", data.error);
setFetchedModels([]);
@@ -80,32 +77,14 @@ function OllamaModalContent({
setFetchedModels(data.models);
})
.finally(() => {
if (!signal.aborted) {
setIsLoadingModels(false);
}
setIsLoadingModels(false);
});
},
[existingLlmProvider?.name, setFetchedModels]
);
const debouncedFetchModels = useMemo(
() => debounce(fetchModels, 500),
[fetchModels]
);
useEffect(() => {
if (formikProps.values.api_base) {
const controller = new AbortController();
debouncedFetchModels(formikProps.values.api_base, controller.signal);
return () => {
debouncedFetchModels.cancel();
controller.abort();
};
} else {
setIsLoadingModels(false);
setFetchedModels([]);
}
}, [formikProps.values.api_base, debouncedFetchModels, setFetchedModels]);
}, [
formikProps.values.api_base,
existingLlmProvider?.name,
setFetchedModels,
]);
const currentModels =
fetchedModels.length > 0