mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 22:55:46 +00:00
Compare commits
6 Commits
action_too
...
nikg/stand
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0cf3f22b3b | ||
|
|
77fd9bb052 | ||
|
|
42f723d17c | ||
|
|
d261f24578 | ||
|
|
b670ac69e8 | ||
|
|
31d95908c4 |
@@ -15,7 +15,6 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
40
AGENTS.md
40
AGENTS.md
@@ -617,6 +617,46 @@ Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Error Handling
|
||||
|
||||
**Always use `OnyxErrorCode` from `onyx.error_handling.error_codes` when raising `HTTPException`. Never hardcode
|
||||
status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
**Reason:** Standardized error codes give API consumers a stable, machine-readable `error_code` field to match on,
|
||||
and keep HTTP status codes consistent across the entire backend.
|
||||
|
||||
```python
|
||||
from fastapi import HTTPException
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
|
||||
# ✅ Good
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.NOT_FOUND.status_code,
|
||||
detail=OnyxErrorCode.NOT_FOUND.detail("Session not found"),
|
||||
)
|
||||
|
||||
# ✅ Good — no extra message needed
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.UNAUTHENTICATED.status_code,
|
||||
detail=OnyxErrorCode.UNAUTHENTICATED.detail(),
|
||||
)
|
||||
|
||||
# ❌ Bad — hardcoded status code
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# ❌ Bad — starlette constant
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
```
|
||||
|
||||
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
|
||||
category is needed, add it there first — do not invent ad-hoc codes.
|
||||
|
||||
**Exception — upstream service errors:** When forwarding errors from an upstream service (e.g.
|
||||
`BillingServiceError`), the HTTP status code is dynamic and comes from the upstream response.
|
||||
In these cases, use `error_code: "BILLING_SERVICE_ERROR"` (or similar) with the upstream message
|
||||
wrapped in the standard `{"error_code": "...", "message": "..."}` dict shape. Do not map these
|
||||
to `OnyxErrorCode` members since the status code is not fixed.
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -170,9 +171,11 @@ async def create_checkout_session(
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if seats < used_seats:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot subscribe with fewer seats than current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {seats} seats.",
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
f"Cannot subscribe with fewer seats than current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {seats} seats."
|
||||
),
|
||||
)
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
@@ -188,7 +191,11 @@ async def create_checkout_session(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
# Preserve upstream status code; wrap message in structured format
|
||||
raise HTTPException(
|
||||
status_code=e.status_code,
|
||||
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
@@ -206,7 +213,10 @@ async def create_customer_portal_session(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail("No license found"),
|
||||
)
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
@@ -217,7 +227,11 @@ async def create_customer_portal_session(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
# Preserve upstream status code; wrap message in structured format
|
||||
raise HTTPException(
|
||||
status_code=e.status_code,
|
||||
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
@@ -241,8 +255,10 @@ async def get_billing_information(
|
||||
# Check circuit breaker (self-hosted only)
|
||||
if _is_billing_circuit_open():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
status_code=OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
|
||||
detail=OnyxErrorCode.SERVICE_UNAVAILABLE.detail(
|
||||
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry."
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -254,7 +270,11 @@ async def get_billing_information(
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
_open_billing_circuit()
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
# Preserve upstream status code; wrap message in structured format
|
||||
raise HTTPException(
|
||||
status_code=e.status_code,
|
||||
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
@@ -274,15 +294,20 @@ async def update_seats(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail("No license found"),
|
||||
)
|
||||
|
||||
# Validate that new seat count is not less than current used seats
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if request.new_seat_count < used_seats:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot reduce seats below current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
f"Cannot reduce seats below current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats."
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -298,7 +323,11 @@ async def update_seats(
|
||||
|
||||
return result
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
# Preserve upstream status code; wrap message in structured format
|
||||
raise HTTPException(
|
||||
status_code=e.status_code,
|
||||
detail={"error_code": "BILLING_SERVICE_ERROR", "message": e.message},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -330,8 +359,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Invalid Stripe publishable key format"
|
||||
),
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
@@ -339,8 +370,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Stripe publishable key is not configured"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -352,16 +385,20 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Invalid Stripe publishable key format"
|
||||
),
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Failed to fetch Stripe publishable key"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -128,8 +129,10 @@ async def claim_license(
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
"License claiming is only available for self-hosted deployments"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -147,14 +150,19 @@ async def claim_license(
|
||||
metadata = get_license_metadata(db_session)
|
||||
if not metadata or not metadata.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No license found. Provide session_id after checkout.",
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
"No license found. Provide session_id after checkout."
|
||||
),
|
||||
)
|
||||
|
||||
license_row = get_license(db_session)
|
||||
if not license_row or not license_row.license_data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No license found in database"
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
"No license found in database"
|
||||
),
|
||||
)
|
||||
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
|
||||
@@ -173,7 +181,10 @@ async def claim_license(
|
||||
license_data = data.get("license")
|
||||
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.NOT_FOUND.status_code,
|
||||
detail=OnyxErrorCode.NOT_FOUND.detail("No license in response"),
|
||||
)
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
@@ -201,10 +212,16 @@ async def claim_license(
|
||||
pass
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(str(e)),
|
||||
)
|
||||
except requests.RequestException:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
status_code=OnyxErrorCode.BAD_GATEWAY.status_code,
|
||||
detail=OnyxErrorCode.BAD_GATEWAY.detail(
|
||||
"Failed to connect to license server"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -222,8 +239,10 @@ async def upload_license(
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
"License upload is only available for self-hosted deployments"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -234,14 +253,20 @@ async def upload_license(
|
||||
# Remove any stray whitespace/newlines from user input
|
||||
license_data = license_data.strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.INVALID_INPUT.status_code,
|
||||
detail=OnyxErrorCode.INVALID_INPUT.detail("Invalid license file format"),
|
||||
)
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(str(e)),
|
||||
)
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
@@ -298,8 +323,10 @@ async def delete_license(
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail(
|
||||
"License deletion is only available for self-hosted deployments"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -116,9 +117,14 @@ async def create_customer_portal_session(
|
||||
try:
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"stripe_customer_portal_url": portal_url}
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Failed to create customer portal session"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
@@ -134,9 +140,14 @@ async def create_checkout_session(
|
||||
try:
|
||||
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
|
||||
return {"stripe_checkout_url": checkout_url}
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to create checkout session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Failed to create checkout session"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
@@ -147,15 +158,25 @@ async def create_subscription_session(
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.VALIDATION_ERROR.status_code,
|
||||
detail=OnyxErrorCode.VALIDATION_ERROR.detail("Tenant ID not found"),
|
||||
)
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Failed to create subscription session"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -187,8 +208,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Invalid Stripe publishable key format"
|
||||
),
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
@@ -196,8 +219,10 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Stripe publishable key is not configured"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -209,14 +234,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Invalid Stripe publishable key format"
|
||||
),
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
status_code=OnyxErrorCode.INTERNAL_ERROR.status_code,
|
||||
detail=OnyxErrorCode.INTERNAL_ERROR.detail(
|
||||
"Failed to fetch Stripe publishable key"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -819,9 +819,7 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
|
||||
@@ -532,7 +532,6 @@ def fetch_default_model(
|
||||
) -> ModelConfiguration | None:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
|
||||
0
backend/onyx/error_handling/__init__.py
Normal file
0
backend/onyx/error_handling/__init__.py
Normal file
103
backend/onyx/error_handling/error_codes.py
Normal file
103
backend/onyx/error_handling/error_codes.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Standardized error codes for the Onyx backend.
|
||||
|
||||
Usage:
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
|
||||
raise HTTPException(
|
||||
status_code=OnyxErrorCode.UNAUTHENTICATED.status_code,
|
||||
detail=OnyxErrorCode.UNAUTHENTICATED.detail("Token expired"),
|
||||
)
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OnyxErrorCode(Enum):
|
||||
"""
|
||||
Each member is a tuple of (error_code_string, http_status_code).
|
||||
|
||||
The error_code_string is a stable, machine-readable identifier that
|
||||
API consumers can match on. The http_status_code is the default HTTP
|
||||
status to return.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication (401)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
|
||||
INVALID_TOKEN = ("INVALID_TOKEN", 401)
|
||||
TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
|
||||
CSRF_FAILURE = ("CSRF_FAILURE", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authorization (403)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHORIZED = ("UNAUTHORIZED", 403)
|
||||
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
|
||||
ADMIN_ONLY = ("ADMIN_ONLY", 403)
|
||||
EE_REQUIRED = ("EE_REQUIRED", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation / Bad Request (400)
|
||||
# ------------------------------------------------------------------
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
# ------------------------------------------------------------------
|
||||
NOT_FOUND = ("NOT_FOUND", 404)
|
||||
CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
|
||||
CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
|
||||
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
|
||||
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
|
||||
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
|
||||
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conflict (409)
|
||||
# ------------------------------------------------------------------
|
||||
CONFLICT = ("CONFLICT", 409)
|
||||
DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rate Limiting / Quotas (429 / 402)
|
||||
# ------------------------------------------------------------------
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
|
||||
CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
|
||||
CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server Errors (5xx)
|
||||
# ------------------------------------------------------------------
|
||||
INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
|
||||
NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
|
||||
def detail(self, message: str | None = None) -> dict[str, str]:
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
}
|
||||
@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
llm = get_default_llm()
|
||||
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
@@ -13,9 +12,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HEALTH_CACHE_TTL_SECONDS = 30
|
||||
_health_cache: dict[str, tuple[float, bool]] = {}
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
@@ -102,32 +98,16 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self, use_cache: bool = False) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy
|
||||
|
||||
Args:
|
||||
use_cache: When True, return a cached result if available and
|
||||
within the TTL window. The cache is always populated
|
||||
after a live request regardless of this flag.
|
||||
"""
|
||||
if use_cache:
|
||||
cached = _health_cache.get(self.base_url)
|
||||
if cached is not None:
|
||||
cached_at, cached_result = cached
|
||||
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
|
||||
return cached_result
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json().get("status") == "ok"
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
result = False
|
||||
|
||||
_health_cache[self.base_url] = (time.monotonic(), result)
|
||||
return result
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
|
||||
@@ -107,11 +107,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
client = CodeInterpreterClient()
|
||||
return client.health(use_cache=True)
|
||||
return server.server_enabled
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -1027,13 +1027,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self._capture("GET", b"")
|
||||
if self.path == "/health":
|
||||
self._respond_json(200, {"status": "ok"})
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_DELETE(self) -> None:
|
||||
self._capture("DELETE", b"")
|
||||
self.send_response(200)
|
||||
@@ -1114,14 +1107,6 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _attach_python_tool_to_default_persona(db_session: Session) -> None:
|
||||
"""Ensure the default persona (id=0) has the PythonTool attached."""
|
||||
|
||||
@@ -103,7 +103,9 @@ class TestCreateCheckoutSession:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert "Stripe error" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert exc_info.value.detail["error_code"] == "BILLING_SERVICE_ERROR"
|
||||
assert exc_info.value.detail["message"] == "Stripe error"
|
||||
|
||||
|
||||
class TestCreateCustomerPortalSession:
|
||||
@@ -134,7 +136,8 @@ class TestCreateCustomerPortalSession:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "No license found" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert exc_info.value.detail["message"] == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.create_portal_service")
|
||||
@@ -241,7 +244,8 @@ class TestUpdateSeats:
|
||||
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "No license found" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert exc_info.value.detail["message"] == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.get_used_seats")
|
||||
@@ -314,7 +318,9 @@ class TestUpdateSeats:
|
||||
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Cannot reduce below 10 seats" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert exc_info.value.detail["error_code"] == "BILLING_SERVICE_ERROR"
|
||||
assert exc_info.value.detail["message"] == "Cannot reduce below 10 seats"
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@@ -344,7 +350,8 @@ class TestCircuitBreaker:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "Connect to Stripe" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert "Connect to Stripe" in exc_info.value.detail["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.MULTI_TENANT", False)
|
||||
|
||||
@@ -70,7 +70,10 @@ class TestGetStripePublishableKey:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert (
|
||||
exc_info.value.detail["message"] == "Invalid Stripe publishable key format"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -96,7 +99,10 @@ class TestGetStripePublishableKey:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert (
|
||||
exc_info.value.detail["message"] == "Invalid Stripe publishable key format"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -118,7 +124,10 @@ class TestGetStripePublishableKey:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Failed to fetch Stripe publishable key" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert (
|
||||
exc_info.value.detail["message"] == "Failed to fetch Stripe publishable key"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -133,7 +142,8 @@ class TestGetStripePublishableKey:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "not configured" in exc_info.value.detail
|
||||
assert isinstance(exc_info.value.detail, dict)
|
||||
assert "not configured" in exc_info.value.detail["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
|
||||
@@ -1,37 +1,25 @@
|
||||
"""Tests for PythonTool availability based on server_enabled flag and health check.
|
||||
"""Tests for PythonTool availability based on server_enabled flag.
|
||||
|
||||
Verifies that PythonTool reports itself as unavailable when either:
|
||||
- CODE_INTERPRETER_BASE_URL is not set, or
|
||||
- CodeInterpreterServer.server_enabled is False in the database, or
|
||||
- The Code Interpreter service health check fails.
|
||||
|
||||
Also verifies that the health check result is cached with a TTL.
|
||||
- CodeInterpreterServer.server_enabled is False in the database.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
|
||||
CLIENT_MODULE = "onyx.tools.tool_implementations.python.code_interpreter_client"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", None)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
None,
|
||||
)
|
||||
def test_python_tool_unavailable_without_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -39,7 +27,10 @@ def test_python_tool_unavailable_without_base_url() -> None:
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "")
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"",
|
||||
)
|
||||
def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -52,8 +43,13 @@ def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
@@ -68,15 +64,18 @@ def test_python_tool_unavailable_when_server_disabled(
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check determines availability when URL + server are OK
|
||||
# Available when both conditions are met
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_available_when_health_check_passes(
|
||||
mock_client_cls: MagicMock,
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
@@ -85,120 +84,5 @@ def test_python_tool_available_when_health_check_passes(
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = True
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_unavailable_when_health_check_fails(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = False
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check is NOT reached when preconditions fail
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_health_check_not_called_when_server_disabled(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = False
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client_cls.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check caching (tested at the client level)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_health_check_cached_on_second_call() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health(use_cache=True) is True
|
||||
assert client.health(use_cache=True) is True
|
||||
# Only one HTTP call — the second used the cache
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
@patch(f"{CLIENT_MODULE}.time")
|
||||
def test_health_check_refreshed_after_ttl_expires(mock_time: MagicMock) -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
_HEALTH_CACHE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
# First call at t=0 — cache miss
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Second call within TTL — cache hit
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS - 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Third call after TTL — cache miss, fresh request
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS + 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
def test_health_check_no_cache_by_default() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health() is True
|
||||
assert client.health() is True
|
||||
# Both calls hit the network when use_cache=False (default)
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 2,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "greptile.json\n",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"customContext": {
|
||||
"other": [
|
||||
{
|
||||
"scope": [],
|
||||
"content": ""
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"scope": [],
|
||||
"rule": ""
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"scope": [],
|
||||
"path": "",
|
||||
"description": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@
|
||||
"test:verbose": "jest --verbose",
|
||||
"test:ci": "jest --ci --maxWorkers=2 --silent --bail",
|
||||
"test:changed": "jest --onlyChanged",
|
||||
"test:diff": "jest --changedSince=main",
|
||||
"test:debug": "node --inspect-brk node_modules/.bin/jest --runInBand"
|
||||
},
|
||||
"dependencies": {
|
||||
|
||||
@@ -187,7 +187,6 @@ export const fetchOllamaModels = async (
|
||||
api_base: apiBase,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@@ -20,7 +20,6 @@ import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/
|
||||
import { Credential } from "@/lib/connectors/credentials";
|
||||
import { useFederatedConnectors } from "@/lib/hooks";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useToastFromQuery } from "@/hooks/useToast";
|
||||
|
||||
export default function ConnectorWrapper({
|
||||
connector,
|
||||
@@ -30,13 +29,6 @@ export default function ConnectorWrapper({
|
||||
const searchParams = useSearchParams();
|
||||
const mode = searchParams?.get("mode"); // 'federated' or 'regular'
|
||||
|
||||
useToastFromQuery({
|
||||
oauth_failed: {
|
||||
message: "OAuth authentication failed. Please try again.",
|
||||
type: "error",
|
||||
},
|
||||
});
|
||||
|
||||
// Check if the connector is valid
|
||||
if (!isValidSource(connector)) {
|
||||
return (
|
||||
|
||||
@@ -2,6 +2,10 @@ import { getDomain } from "@/lib/redirectSS";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
import { cookies } from "next/headers";
|
||||
import {
|
||||
GMAIL_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
} from "@/lib/constants";
|
||||
import {
|
||||
CRAFT_OAUTH_COOKIE_NAME,
|
||||
CRAFT_CONFIGURE_PATH,
|
||||
@@ -11,7 +15,6 @@ import { processCookies } from "@/lib/userSS";
|
||||
export const GET = async (request: NextRequest) => {
|
||||
const requestCookies = await cookies();
|
||||
const connector = request.url.includes("gmail") ? "gmail" : "google-drive";
|
||||
|
||||
const callbackEndpoint = `/manage/connector/${connector}/callback`;
|
||||
const url = new URL(buildUrl(callbackEndpoint));
|
||||
url.search = request.nextUrl.search;
|
||||
@@ -23,12 +26,7 @@ export const GET = async (request: NextRequest) => {
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return NextResponse.redirect(
|
||||
new URL(
|
||||
`/admin/connectors/${connector}?message=oauth_failed`,
|
||||
getDomain(request)
|
||||
)
|
||||
);
|
||||
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
|
||||
}
|
||||
|
||||
// Check for build mode OAuth flag (redirects to build admin panel)
|
||||
@@ -42,7 +40,16 @@ export const GET = async (request: NextRequest) => {
|
||||
return redirectResponse;
|
||||
}
|
||||
|
||||
return NextResponse.redirect(
|
||||
new URL(`/admin/connectors/${connector}`, getDomain(request))
|
||||
);
|
||||
const authCookieName =
|
||||
connector === "gmail"
|
||||
? GMAIL_AUTH_IS_ADMIN_COOKIE_NAME
|
||||
: GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME;
|
||||
|
||||
if (requestCookies.get(authCookieName)?.value?.toLowerCase() === "true") {
|
||||
return NextResponse.redirect(
|
||||
new URL(`/admin/connectors/${connector}`, getDomain(request))
|
||||
);
|
||||
}
|
||||
|
||||
return NextResponse.redirect(new URL("/user/connectors", getDomain(request)));
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { getSourceMetadata, isValidSource } from "@/lib/sources";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
@@ -9,6 +9,7 @@ import CardSection from "@/components/admin/CardSection";
|
||||
import { handleOAuthAuthorizationResponse } from "@/lib/oauth_utils";
|
||||
import { SvgKey } from "@opal/icons";
|
||||
export default function OAuthCallbackPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
const [statusMessage, setStatusMessage] = useState("Processing...");
|
||||
|
||||
@@ -6,7 +6,11 @@ import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
import {
|
||||
DOCS_ADMINS_PATH,
|
||||
GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
} from "@/lib/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import { TextFormField, SectionHeader } from "@/components/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
@@ -588,6 +592,11 @@ export const DriveAuthSection = ({
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
// cookie used by callback to determine where to finally redirect to
|
||||
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
|
||||
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
|
||||
isAdmin: true,
|
||||
name: "OAuth (uploaded)",
|
||||
|
||||
@@ -7,7 +7,10 @@ import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGmailOAuth } from "@/lib/gmail";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
import {
|
||||
DOCS_ADMINS_PATH,
|
||||
GMAIL_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
} from "@/lib/constants";
|
||||
import { CRAFT_OAUTH_COOKIE_NAME } from "@/app/craft/v1/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import { TextFormField, SectionHeader } from "@/components/Field";
|
||||
@@ -599,6 +602,9 @@ export const GmailAuthSection = ({
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
Cookies.set(GMAIL_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
if (buildMode) {
|
||||
Cookies.set(CRAFT_OAUTH_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
|
||||
76
web/src/app/ee/admin/theme/ImageUpload.tsx
Normal file
76
web/src/app/ee/admin/theme/ImageUpload.tsx
Normal file
@@ -0,0 +1,76 @@
|
||||
import { SubLabel } from "@/components/Field";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useEffect, useState } from "react";
|
||||
import Dropzone from "react-dropzone";
|
||||
|
||||
export function ImageUpload({
|
||||
selectedFile,
|
||||
setSelectedFile,
|
||||
}: {
|
||||
selectedFile: File | null;
|
||||
setSelectedFile: (file: File) => void;
|
||||
}) {
|
||||
const [tmpImageUrl, setTmpImageUrl] = useState<string>("");
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedFile) {
|
||||
setTmpImageUrl(URL.createObjectURL(selectedFile));
|
||||
} else {
|
||||
setTmpImageUrl("");
|
||||
}
|
||||
}, [selectedFile]);
|
||||
|
||||
const [dragActive, setDragActive] = useState(false);
|
||||
|
||||
return (
|
||||
<Dropzone
|
||||
onDrop={(acceptedFiles) => {
|
||||
if (acceptedFiles.length !== 1) {
|
||||
toast.error("Only one file can be uploaded at a time");
|
||||
return;
|
||||
}
|
||||
|
||||
const acceptedFile = acceptedFiles[0];
|
||||
if (acceptedFile === undefined) {
|
||||
toast.error("acceptedFile cannot be undefined");
|
||||
return;
|
||||
}
|
||||
|
||||
setTmpImageUrl(URL.createObjectURL(acceptedFile));
|
||||
setSelectedFile(acceptedFile);
|
||||
setDragActive(false);
|
||||
}}
|
||||
onDragLeave={() => setDragActive(false)}
|
||||
onDragEnter={() => setDragActive(true)}
|
||||
>
|
||||
{({ getRootProps, getInputProps }) => (
|
||||
<section>
|
||||
<div
|
||||
{...getRootProps()}
|
||||
className={
|
||||
"flex flex-col items-center w-full px-4 py-12 rounded " +
|
||||
"shadow-lg tracking-wide border border-border cursor-pointer" +
|
||||
(dragActive ? " border-accent" : "")
|
||||
}
|
||||
>
|
||||
<input {...getInputProps()} />
|
||||
<b className="text-text-darker">
|
||||
Drag and drop a .png or .jpg file, or click to select a file!
|
||||
</b>
|
||||
</div>
|
||||
|
||||
{tmpImageUrl && (
|
||||
<div className="mt-4 mb-8">
|
||||
<SubLabel>Uploaded Image:</SubLabel>
|
||||
<img
|
||||
alt="Uploaded Image"
|
||||
src={tmpImageUrl}
|
||||
className="mt-4 max-w-xs max-h-64"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
)}
|
||||
</Dropzone>
|
||||
);
|
||||
}
|
||||
@@ -16,6 +16,8 @@ import { cn } from "@/lib/utils";
|
||||
|
||||
const CsvContent: React.FC<ContentComponentProps> = ({
|
||||
fileDescriptor,
|
||||
isLoading,
|
||||
fadeIn,
|
||||
expanded = false,
|
||||
}) => {
|
||||
const [data, setData] = useState<Record<string, string>[]>([]);
|
||||
@@ -92,7 +94,7 @@ const CsvContent: React.FC<ContentComponentProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
if (isFetching) {
|
||||
if (isLoading || isFetching) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-[300px]">
|
||||
<SimpleLoader />
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// ExpandableContentWrapper
|
||||
import React, { useState } from "react";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { SvgDownloadCloud, SvgFold, SvgMaximize2, SvgX } from "@opal/icons";
|
||||
import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card";
|
||||
import { Button } from "@opal/components";
|
||||
@@ -17,6 +17,8 @@ export interface ExpandableContentWrapperProps {
|
||||
|
||||
export interface ContentComponentProps {
|
||||
fileDescriptor: FileDescriptor;
|
||||
isLoading: boolean;
|
||||
fadeIn: boolean;
|
||||
expanded?: boolean;
|
||||
}
|
||||
|
||||
@@ -26,9 +28,24 @@ export default function ExpandableContentWrapper({
|
||||
ContentComponent,
|
||||
}: ExpandableContentWrapperProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [fadeIn, setFadeIn] = useState(false);
|
||||
|
||||
const toggleExpand = () => setExpanded((prev) => !prev);
|
||||
|
||||
// Prevent a jarring fade in
|
||||
useEffect(() => {
|
||||
setTimeout(() => setIsLoading(false), 300);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isLoading) {
|
||||
setTimeout(() => setFadeIn(true), 50);
|
||||
} else {
|
||||
setFadeIn(false);
|
||||
}
|
||||
}, [isLoading]);
|
||||
|
||||
const downloadFile = () => {
|
||||
const a = document.createElement("a");
|
||||
a.href = `api/chat/file/${fileDescriptor.id}`;
|
||||
@@ -86,6 +103,8 @@ export default function ExpandableContentWrapper({
|
||||
{!expanded && (
|
||||
<ContentComponent
|
||||
fileDescriptor={fileDescriptor}
|
||||
isLoading={isLoading}
|
||||
fadeIn={fadeIn}
|
||||
expanded={expanded}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -118,7 +118,6 @@ export interface BedrockFetchParams {
|
||||
export interface OllamaFetchParams {
|
||||
api_base?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OpenRouterFetchParams {
|
||||
|
||||
@@ -28,6 +28,11 @@ export const NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED =
|
||||
|
||||
export const TENANT_ID_COOKIE_NAME = "onyx_tid";
|
||||
|
||||
export const GMAIL_AUTH_IS_ADMIN_COOKIE_NAME = "gmail_auth_is_admin";
|
||||
|
||||
export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME =
|
||||
"google_drive_auth_is_admin";
|
||||
|
||||
export const SEARCH_TYPE_COOKIE_NAME = "search_type";
|
||||
export const AGENTIC_SEARCH_TYPE_COOKIE_NAME = "agentic_type";
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import { Button as OpalButton } from "@opal/components";
|
||||
import InputAvatar from "@/refresh-components/inputs/InputAvatar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgCheckCircle, SvgEdit, SvgUser, SvgX } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
|
||||
export default function NonAdminStep() {
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
@@ -55,26 +54,17 @@ export default function NonAdminStep() {
|
||||
className="flex items-center justify-between w-full min-h-11 py-1 pl-3 pr-2 bg-background-tint-00 rounded-16 shadow-01 mb-2"
|
||||
aria-label="non-admin-confirmation"
|
||||
>
|
||||
<ContentAction
|
||||
icon={({ className, ...props }) => (
|
||||
<SvgCheckCircle
|
||||
className={cn(className, "stroke-status-success-05")}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
title="You're all set!"
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
paddingVariant="fit"
|
||||
rightChildren={
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgX}
|
||||
onClick={handleDismissConfirmation}
|
||||
/>
|
||||
}
|
||||
<div className="flex items-center gap-1">
|
||||
<SvgCheckCircle className="w-4 h-4 stroke-status-success-05" />
|
||||
<Text as="p" text03 mainUiBody>
|
||||
You're all set!
|
||||
</Text>
|
||||
</div>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgX}
|
||||
onClick={handleDismissConfirmation}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
@@ -85,36 +75,39 @@ export default function NonAdminStep() {
|
||||
role="group"
|
||||
aria-label="non-admin-name-prompt"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgUser}
|
||||
title="What should Onyx call you?"
|
||||
description="We will display this name in the app."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
rightChildren={
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<InputTypeIn
|
||||
ref={inputRef}
|
||||
placeholder="Your name"
|
||||
value={name || ""}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) =>
|
||||
setName(e.target.value)
|
||||
}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && name && name.trim().length > 0) {
|
||||
e.preventDefault();
|
||||
handleSave();
|
||||
}
|
||||
}}
|
||||
className="w-[26%] min-w-40"
|
||||
/>
|
||||
<Button disabled={name === ""} onClick={handleSave}>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
<div className="flex items-center gap-1 h-full">
|
||||
<div className="h-full p-0.5">
|
||||
<SvgUser className="w-4 h-4 stroke-text-03" />
|
||||
</div>
|
||||
<div>
|
||||
<Text as="p" text04 mainUiAction>
|
||||
What should Onyx call you?
|
||||
</Text>
|
||||
<Text as="p" text03 secondaryBody>
|
||||
We will display this name in the app.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<InputTypeIn
|
||||
ref={inputRef}
|
||||
placeholder="Your name"
|
||||
value={name || ""}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) =>
|
||||
setName(e.target.value)
|
||||
}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && name && name.trim().length > 0) {
|
||||
e.preventDefault();
|
||||
handleSave();
|
||||
}
|
||||
}}
|
||||
className="w-[26%] min-w-40"
|
||||
/>
|
||||
<Button disabled={name === ""} onClick={handleSave}>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"use client";
|
||||
|
||||
import { memo, useState, useCallback } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
@@ -14,7 +12,6 @@ import {
|
||||
import { Disabled } from "@/refresh-components/Disabled";
|
||||
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import { SvgCheckCircle, SvgCpu, SvgExternalLink } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
|
||||
type LLMStepProps = {
|
||||
state: OnboardingState;
|
||||
@@ -125,14 +122,21 @@ const LLMStepInner = ({
|
||||
className="flex flex-col items-center justify-between w-full p-1 rounded-16 border border-border-01 bg-background-tint-00"
|
||||
aria-label="onboarding-llm-step"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgCpu}
|
||||
title="Connect your LLM models"
|
||||
description="Onyx supports both self-hosted models and popular providers."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="lg"
|
||||
rightChildren={
|
||||
<div className="flex gap-2 justify-between h-full w-full">
|
||||
<div className="flex mx-2 mt-2 gap-1">
|
||||
<div className="h-full p-0.5">
|
||||
<SvgCpu className="w-4 h-4 stroke-text-03" />
|
||||
</div>
|
||||
<div>
|
||||
<Text as="p" text04 mainUiAction>
|
||||
Connect your LLM models
|
||||
</Text>
|
||||
<Text as="p" text03 secondaryBody>
|
||||
Onyx supports both self-hosted models and popular providers.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<div className="p-0.5">
|
||||
<Button
|
||||
tertiary
|
||||
rightIcon={SvgExternalLink}
|
||||
@@ -141,8 +145,8 @@ const LLMStepInner = ({
|
||||
>
|
||||
View in Admin Panel
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<Separator />
|
||||
<div className="flex flex-wrap gap-1 [&>*:last-child:nth-child(odd)]:basis-full">
|
||||
{isLoading ? (
|
||||
|
||||
@@ -8,7 +8,6 @@ import InputAvatar from "@/refresh-components/inputs/InputAvatar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { SvgCheckCircle, SvgEdit, SvgUser } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
|
||||
export interface NameStepProps {
|
||||
state: OnboardingState;
|
||||
@@ -41,23 +40,26 @@ const NameStep = React.memo(
|
||||
role="group"
|
||||
aria-label="onboarding-name-step"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgUser}
|
||||
title="What should Onyx call you?"
|
||||
description="We will display this name in the app."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
rightChildren={
|
||||
<InputTypeIn
|
||||
ref={inputRef}
|
||||
placeholder="Your name"
|
||||
value={userName || ""}
|
||||
onChange={(e) => updateName(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="max-w-60"
|
||||
/>
|
||||
}
|
||||
<div className="flex items-center gap-1 h-full">
|
||||
<div className="h-full p-0.5">
|
||||
<SvgUser className="w-4 h-4 stroke-text-03" />
|
||||
</div>
|
||||
<div>
|
||||
<Text as="p" text04 mainUiAction>
|
||||
What should Onyx call you?
|
||||
</Text>
|
||||
<Text as="p" text03 secondaryBody>
|
||||
We will display this name in the app.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<InputTypeIn
|
||||
ref={inputRef}
|
||||
placeholder="Your name"
|
||||
value={userName || ""}
|
||||
onChange={(e) => updateName(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="max-w-60"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -25,7 +25,7 @@ export interface ActionItemProps {
|
||||
disabled: boolean;
|
||||
isForced: boolean;
|
||||
isUnavailable?: boolean;
|
||||
tooltip?: string;
|
||||
unavailableReason?: string;
|
||||
showAdminConfigure?: boolean;
|
||||
adminConfigureHref?: string;
|
||||
adminConfigureTooltip?: string;
|
||||
@@ -47,7 +47,7 @@ export default function ActionLineItem({
|
||||
disabled,
|
||||
isForced,
|
||||
isUnavailable = false,
|
||||
tooltip,
|
||||
unavailableReason,
|
||||
showAdminConfigure = false,
|
||||
adminConfigureHref,
|
||||
adminConfigureTooltip = "Configure",
|
||||
@@ -88,7 +88,7 @@ export default function ActionLineItem({
|
||||
sourceCounts.enabled > 0 &&
|
||||
sourceCounts.enabled < sourceCounts.total;
|
||||
|
||||
const tooltipText = tooltip || tool?.description;
|
||||
const tooltipText = isUnavailable ? unavailableReason : tool?.description;
|
||||
|
||||
return (
|
||||
<SimpleTooltip tooltip={tooltipText} className="max-w-[30rem]">
|
||||
|
||||
@@ -42,45 +42,26 @@ import { useProjectsContext } from "@/providers/ProjectsContext";
|
||||
import { SvgActions, SvgChevronRight, SvgKey, SvgSliders } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
|
||||
function buildTooltipMessage(
|
||||
actionDescription: string,
|
||||
isConfigured: boolean,
|
||||
isAdmin: boolean
|
||||
) {
|
||||
const _ADMIN_CONFIGURE_MESSAGE = "Press the settings cog to enable.";
|
||||
const _USER_NOT_ADMIN_MESSAGE = "Ask an admin to configure.";
|
||||
|
||||
if (isConfigured) {
|
||||
return actionDescription;
|
||||
}
|
||||
|
||||
if (isAdmin) {
|
||||
return actionDescription + " " + _ADMIN_CONFIGURE_MESSAGE;
|
||||
}
|
||||
|
||||
return actionDescription + " " + _USER_NOT_ADMIN_MESSAGE;
|
||||
}
|
||||
|
||||
const TOOL_DESCRIPTIONS: Record<string, string> = {
|
||||
[SEARCH_TOOL_ID]: "Search through connected knowledge to inform the answer.",
|
||||
[IMAGE_GENERATION_TOOL_ID]: "Generate images based on a prompt.",
|
||||
[WEB_SEARCH_TOOL_ID]: "Search the web for up-to-date information.",
|
||||
[PYTHON_TOOL_ID]: "Execute code for complex analysis.",
|
||||
const UNAVAILABLE_TOOL_TOOLTIP_FALLBACK =
|
||||
"This action is not configured yet. Ask an admin to enable it.";
|
||||
const UNAVAILABLE_TOOL_TOOLTIP_ADMIN_FALLBACK =
|
||||
"This action is not configured yet. If you have access, enable it in the admin panel.";
|
||||
const UNAVAILABLE_TOOL_TOOLTIPS: Record<string, string> = {
|
||||
[IMAGE_GENERATION_TOOL_ID]:
|
||||
"Image generation requires a configured model. If you have access, set one up under Settings > Image Generation, or ask an admin.",
|
||||
[WEB_SEARCH_TOOL_ID]:
|
||||
"Web search requires a configured provider. If you have access, set one up under Settings > Web Search, or ask an admin.",
|
||||
[PYTHON_TOOL_ID]:
|
||||
"Code Interpreter requires the service to be configured with a valid base URL. If you have access, configure it in the admin panel, or ask an admin.",
|
||||
};
|
||||
|
||||
const DEFAULT_TOOL_DESCRIPTION = "This action is not configured yet.";
|
||||
|
||||
function getToolTooltip(
|
||||
tool: ToolSnapshot,
|
||||
isConfigured: boolean,
|
||||
isAdmin: boolean
|
||||
): string {
|
||||
const description =
|
||||
(tool.in_code_tool_id && TOOL_DESCRIPTIONS[tool.in_code_tool_id]) ||
|
||||
tool.description ||
|
||||
DEFAULT_TOOL_DESCRIPTION;
|
||||
return buildTooltipMessage(description, isConfigured, isAdmin);
|
||||
}
|
||||
const getUnavailableToolTooltip = (
|
||||
inCodeToolId?: string | null,
|
||||
canAdminConfigure?: boolean
|
||||
) =>
|
||||
(inCodeToolId && UNAVAILABLE_TOOL_TOOLTIPS[inCodeToolId]) ??
|
||||
(canAdminConfigure
|
||||
? UNAVAILABLE_TOOL_TOOLTIP_ADMIN_FALLBACK
|
||||
: UNAVAILABLE_TOOL_TOOLTIP_FALLBACK);
|
||||
|
||||
const ADMIN_CONFIG_LINKS: Record<string, { href: string; tooltip: string }> = {
|
||||
[IMAGE_GENERATION_TOOL_ID]: {
|
||||
@@ -91,10 +72,6 @@ const ADMIN_CONFIG_LINKS: Record<string, { href: string; tooltip: string }> = {
|
||||
href: "/admin/configuration/web-search",
|
||||
tooltip: "Configure Web Search",
|
||||
},
|
||||
[PYTHON_TOOL_ID]: {
|
||||
href: "/admin/configuration/code-interpreter",
|
||||
tooltip: "Configure Code Interpreter",
|
||||
},
|
||||
KnowledgeGraphTool: {
|
||||
href: "/admin/kg",
|
||||
tooltip: "Configure Knowledge Graph",
|
||||
@@ -915,11 +892,14 @@ export default function ActionsPopover({
|
||||
disabled={disabledToolIds.includes(tool.id)}
|
||||
isForced={forcedToolIds.includes(tool.id)}
|
||||
isUnavailable={isUnavailable}
|
||||
tooltip={getToolTooltip(
|
||||
tool,
|
||||
isToolAvailable,
|
||||
canAdminConfigure
|
||||
)}
|
||||
unavailableReason={
|
||||
isUnavailable
|
||||
? getUnavailableToolTooltip(
|
||||
tool.in_code_tool_id,
|
||||
canAdminConfigure
|
||||
)
|
||||
: undefined
|
||||
}
|
||||
showAdminConfigure={!!adminConfigureInfo}
|
||||
adminConfigureHref={adminConfigureInfo?.href}
|
||||
adminConfigureTooltip={adminConfigureInfo?.tooltip}
|
||||
|
||||
@@ -23,9 +23,8 @@ import {
|
||||
} from "./formUtils";
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { DisplayModels } from "./components/DisplayModels";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
|
||||
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
@@ -62,16 +61,14 @@ function OllamaModalContent({
|
||||
}: OllamaModalContentProps) {
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
const fetchModels = useCallback(
|
||||
(apiBase: string, signal: AbortSignal) => {
|
||||
useEffect(() => {
|
||||
if (formikProps.values.api_base) {
|
||||
setIsLoadingModels(true);
|
||||
fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
api_base: formikProps.values.api_base,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
signal,
|
||||
})
|
||||
.then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
console.error("Error fetching models:", data.error);
|
||||
setFetchedModels([]);
|
||||
@@ -80,32 +77,14 @@ function OllamaModalContent({
|
||||
setFetchedModels(data.models);
|
||||
})
|
||||
.finally(() => {
|
||||
if (!signal.aborted) {
|
||||
setIsLoadingModels(false);
|
||||
}
|
||||
setIsLoadingModels(false);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, setFetchedModels]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(fetchModels, 500),
|
||||
[fetchModels]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (formikProps.values.api_base) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(formikProps.values.api_base, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setIsLoadingModels(false);
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [formikProps.values.api_base, debouncedFetchModels, setFetchedModels]);
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
existingLlmProvider?.name,
|
||||
setFetchedModels,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
|
||||
Reference in New Issue
Block a user