Compare commits

..

1 Commits

Author SHA1 Message Date
Nik
20b2da4104 refactor: replace HTTPException with OnyxError in MCP API 2026-03-04 23:11:49 -08:00

View File

@@ -12,7 +12,6 @@ from urllib.parse import urlparse
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from mcp.client.auth import OAuthClientProvider
from mcp.client.auth import TokenStorage
@@ -58,6 +57,8 @@ from onyx.db.models import User
from onyx.db.tools import create_tool__no_commit
from onyx.db.tools import delete_tool__no_commit
from onyx.db.tools import get_tools_by_mcp_server_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.mcp.models import MCPApiKeyResponse
from onyx.server.features.mcp.models import MCPAuthTemplate
@@ -139,7 +140,7 @@ class OnyxTokenStorage(TokenStorage):
def _ensure_connection_config(self, db_session: Session) -> MCPConnectionConfig:
config = get_connection_config_by_id(self.connection_config_id, db_session)
if config is None:
raise HTTPException(status_code=404, detail="Connection config not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Connection config not found")
return config
async def get_tokens(self) -> OAuthToken | None:
@@ -379,16 +380,16 @@ async def _connect_oauth(
server_id = int(request.server_id)
mcp_server = get_mcp_server_by_id(server_id, db)
except Exception:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
if is_admin:
_ensure_mcp_server_owner_or_admin(mcp_server, user)
if mcp_server.auth_type != MCPAuthenticationType.OAUTH:
auth_type_str = mcp_server.auth_type.value if mcp_server.auth_type else "None"
raise HTTPException(
status_code=400,
detail=f"Server was configured with authentication type {auth_type_str}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Server was configured with authentication type {auth_type_str}",
)
# Create admin config with client info if provided
@@ -410,9 +411,9 @@ async def _connect_oauth(
if mcp_server.admin_connection_config_id is None:
if not is_admin:
raise HTTPException(
status_code=400,
detail="Admin connection config not found for this server",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Admin connection config not found for this server",
)
admin_config = create_connection_config(
@@ -453,9 +454,9 @@ async def _connect_oauth(
# Ensure we have a trailing slash for the MCP endpoint
if mcp_server.transport is None:
raise HTTPException(
status_code=400,
detail="MCP server transport is not configured",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"MCP server transport is not configured",
)
# always make a http request for the initial probe
@@ -533,10 +534,13 @@ async def _connect_oauth(
)
except Exception as e:
logger.error(f"OAuth initialization failed during timeout: {e}")
raise HTTPException(
status_code=400, detail=f"OAuth initialization failed: {str(e)}"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"OAuth initialization failed: {str(e)}",
)
raise HTTPException(status_code=400, detail="Auth URL retrieval timed out")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, "Auth URL retrieval timed out"
)
logger.info(
f"Connected to auth url: {oauth_url} for mcp server: {mcp_server.name}"
@@ -558,8 +562,9 @@ async def _connect_oauth(
saved_e = e
logger.error(f"OAuth initialization failed: {saved_e}")
# If initialize failed and we also didn't get an auth URL, surface an error
raise HTTPException(
status_code=400, detail=f"Failed to initialize OAuth client: {str(saved_e)}"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Failed to initialize OAuth client: {str(saved_e)}",
)
return MCPUserOAuthConnectResponse(
@@ -590,20 +595,20 @@ async def process_oauth_callback(
code = callback_data.get("code")
user_id = str(user.id)
if not state:
raise HTTPException(status_code=400, detail="Missing state parameter")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Missing state parameter")
if not code:
raise HTTPException(status_code=400, detail="Missing code parameter")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Missing code parameter")
stored_data = cast(bytes, redis_client.get(key_state(user_id)))
if not stored_data:
raise HTTPException(
status_code=400, detail="Invalid or expired state parameter"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, "Invalid or expired state parameter"
)
state_data = MCPOauthState.model_validate_json(stored_data)
try:
server_id = state_data.server_id
mcp_server = get_mcp_server_by_id(server_id, db_session)
except Exception:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
user_id = str(user.id)
@@ -615,9 +620,9 @@ async def process_oauth_callback(
admin_config = mcp_server.admin_connection_config
if admin_config is None:
raise HTTPException(
status_code=400,
detail="Server referenced by callback is not configured, try recreating",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Server referenced by callback is not configured, try recreating",
)
# Run the blocking blpop operation in a thread pool to avoid blocking the event loop
@@ -629,12 +634,14 @@ async def process_oauth_callback(
lambda: r.blpop([key_tokens(str(admin_config_id))], timeout=OAUTH_WAIT_SECONDS),
)
if tokens_raw is None:
raise HTTPException(status_code=400, detail="No tokens found")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No tokens found")
tokens_bytes = cast(tuple[bytes, bytes], tokens_raw)
tokens = OAuthToken.model_validate_json(tokens_bytes[1].decode())
if not tokens.access_token:
raise HTTPException(status_code=400, detail="No access_token in OAuth response")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, "No access_token in OAuth response"
)
db_session.commit()
@@ -667,12 +674,12 @@ def save_user_credentials(
server_id = request.server_id
mcp_server = get_mcp_server_by_id(server_id, db_session)
except Exception:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
if mcp_server.auth_type == "none":
raise HTTPException(
status_code=400,
detail="Server does not require authentication",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Server does not require authentication",
)
email = user.email
@@ -682,9 +689,9 @@ def save_user_credentials(
if not auth_template:
# Fallback to simple API key storage for servers without templates
if "api_key" not in request.credentials:
raise HTTPException(
status_code=400,
detail="No authentication template found and no api_key provided",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No authentication template found and no api_key provided",
)
config_data = MCPConnectionData(
headers={"Authorization": f"Bearer {request.credentials['api_key']}"},
@@ -709,9 +716,9 @@ def save_user_credentials(
except Exception as e:
logger.error(f"Failed to process authentication template: {e}")
raise HTTPException(
status_code=400,
detail=f"Failed to process authentication template: {str(e)}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Failed to process authentication template: {str(e)}",
)
# Test the credentials before saving
@@ -746,17 +753,17 @@ def save_user_credentials(
validation_tested = True
if not is_valid:
raise HTTPException(
status_code=400,
detail=f"Credentials validation failed: {test_message}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Credentials validation failed: {test_message}",
)
else:
validation_message = (
f"Credentials saved and validated successfully. {test_message}"
)
except HTTPException:
raise # Re-raise HTTP exceptions
except OnyxError:
raise # Re-raise OnyxError exceptions
except Exception as e:
logger.warning(
f"Could not validate credentials for server {mcp_server.name}: {e}"
@@ -788,7 +795,7 @@ def save_user_credentials(
except Exception as e:
logger.error(f"Failed to save user credentials: {e}")
raise HTTPException(status_code=500, detail="Failed to save user credentials")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to save user credentials")
class MCPToolDescription(BaseModel):
@@ -814,9 +821,9 @@ def _ensure_mcp_server_owner_or_admin(server: DbMCPServer, user: User) -> None:
logger.info(f"User email: {user.email} server.owner={server.owner}")
if server.owner != user.email:
raise HTTPException(
status_code=403,
detail="Curators can only modify MCP servers that they have created.",
raise OnyxError(
OnyxErrorCode.UNAUTHORIZED,
"Curators can only modify MCP servers that they have created.",
)
@@ -1004,10 +1011,10 @@ def get_mcp_servers_for_assistant(
return MCPServersResponse(assistant_id=assistant_id, mcp_servers=mcp_servers)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid assistant ID")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid assistant ID")
except Exception as e:
logger.error(f"Failed to fetch MCP servers: {e}")
raise HTTPException(status_code=500, detail="Failed to fetch MCP servers")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to fetch MCP servers")
@router.get("/servers", response_model=MCPServersResponse)
@@ -1058,9 +1065,9 @@ def _get_connection_config(
)
if not connection_config:
raise HTTPException(
status_code=401,
detail="Authentication required for this MCP server",
raise OnyxError(
OnyxErrorCode.UNAUTHENTICATED,
"Authentication required for this MCP server",
)
return connection_config
@@ -1101,7 +1108,7 @@ def get_mcp_server_tools_snapshots(
# Verify the server exists
mcp_server = get_mcp_server_by_id(server_id, db)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
_ensure_mcp_server_owner_or_admin(mcp_server, user)
@@ -1126,12 +1133,12 @@ def get_mcp_server_tools_snapshots(
)
db.commit()
if isinstance(e, HTTPException):
# Re-raise HTTP exceptions (e.g. 401, 400) so they are returned to client
if isinstance(e, OnyxError):
# Re-raise OnyxError exceptions (e.g. 401, 400) so they are returned to client
raise e
logger.error(f"Failed to discover tools for MCP server: {e}")
raise HTTPException(status_code=500, detail="Failed to discover tools")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to discover tools")
# Fetch and return tools from database
mcp_tools = get_tools_by_mcp_server_id(server_id, db, order_by_id=True)
@@ -1209,7 +1216,7 @@ def _list_mcp_tools_by_id(
# Get the MCP server
mcp_server = get_mcp_server_by_id(server_id, db)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
if is_admin:
_ensure_mcp_server_owner_or_admin(mcp_server, user)
@@ -1225,9 +1232,9 @@ def _list_mcp_tools_by_id(
MCPAuthenticationType.NONE,
MCPAuthenticationType.PT_OAUTH,
):
raise HTTPException(
status_code=401,
detail="This MCP server is not configured yet",
raise OnyxError(
OnyxErrorCode.UNAUTHENTICATED,
"This MCP server is not configured yet",
)
user_id = str(user.id)
@@ -1251,9 +1258,9 @@ def _list_mcp_tools_by_id(
user_oauth_token = user.oauth_accounts[0].access_token
headers["Authorization"] = f"Bearer {user_oauth_token}"
else:
raise HTTPException(
status_code=401,
detail="Pass-through OAuth requires a user logged in with OAuth",
raise OnyxError(
OnyxErrorCode.UNAUTHENTICATED,
"Pass-through OAuth requires a user logged in with OAuth",
)
if connection_config:
@@ -1269,9 +1276,9 @@ def _list_mcp_tools_by_id(
server_url = mcp_server.server_url
if mcp_server.transport is None:
raise HTTPException(
status_code=400,
detail="MCP server transport is not configured",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"MCP server transport is not configured",
)
discovered_tools = discover_mcp_tools(
@@ -1341,9 +1348,9 @@ def _upsert_mcp_server(
try:
mcp_server = get_mcp_server_by_id(request.existing_server_id, db_session)
except ValueError:
raise HTTPException(
status_code=404,
detail=f"MCP server with ID {request.existing_server_id} not found",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"MCP server with ID {request.existing_server_id} not found",
)
_ensure_mcp_server_owner_or_admin(mcp_server, user)
client_info = None
@@ -1414,12 +1421,12 @@ def _upsert_mcp_server(
# Prevent duplicate server creation with same URL
normalized_url = (request.server_url or "").strip()
if not normalized_url:
raise HTTPException(status_code=400, detail="server_url is required")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "server_url is required")
if not user.email:
raise HTTPException(
status_code=400,
detail="Authenticated user email required to create MCP servers",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authenticated user email required to create MCP servers",
)
mcp_server = create_mcp_server__no_commit(
@@ -1525,9 +1532,9 @@ def _upsert_mcp_server(
db_session=db_session,
)
elif request.auth_performer == MCPAuthenticationPerformer.ADMIN:
raise HTTPException(
status_code=400,
detail="Admin authentication is not yet supported for MCP servers: user per-user",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Admin authentication is not yet supported for MCP servers: user per-user",
)
# Update server with config IDs
@@ -1579,7 +1586,7 @@ def get_mcp_server_detail(
try:
server = get_mcp_server_by_id(server_id, db_session)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
_ensure_mcp_server_owner_or_admin(server, user)
@@ -1587,7 +1594,7 @@ def get_mcp_server_detail(
# permissions are based on access to assistants
# # Quick permission check admin or user has access
# if user and server not in user.accessible_mcp_servers and not user.is_superuser:
# raise HTTPException(status_code=403, detail="Forbidden")
# raise OnyxError(OnyxErrorCode.UNAUTHORIZED, "Forbidden")
return _db_mcp_server_to_api_mcp_server(
server,
@@ -1628,7 +1635,7 @@ def update_mcp_server_status(
try:
mcp_server = get_mcp_server_by_id(server_id, db)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
_ensure_mcp_server_owner_or_admin(mcp_server, user)
@@ -1665,7 +1672,7 @@ def get_mcp_servers_for_admin(
except Exception as e:
logger.error(f"Failed to fetch MCP servers for admin: {type(e)}:{e}")
raise HTTPException(status_code=500, detail="Failed to fetch MCP servers")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to fetch MCP servers")
@admin_router.get("/server/{server_id}/db-tools")
@@ -1681,7 +1688,7 @@ def get_mcp_server_db_tools(
# Verify the server exists
mcp_server = get_mcp_server_by_id(server_id, db)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
_ensure_mcp_server_owner_or_admin(mcp_server, user)
@@ -1723,8 +1730,9 @@ def upsert_mcp_server(
# Validate auth_performer for non-none auth types
if request.auth_type != MCPAuthenticationType.NONE and not request.auth_performer:
raise HTTPException(
status_code=400, detail="auth_performer is required for non-none auth types"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"auth_performer is required for non-none auth types",
)
try:
@@ -1735,8 +1743,9 @@ def upsert_mcp_server(
not in (MCPAuthenticationType.NONE, MCPAuthenticationType.PT_OAUTH)
and mcp_server.admin_connection_config_id is None
):
raise HTTPException(
status_code=500, detail="Failed to set admin connection config"
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to set admin connection config",
)
db_session.commit()
@@ -1746,8 +1755,9 @@ def upsert_mcp_server(
)
if mcp_server.auth_type is None:
raise HTTPException(
status_code=500, detail="MCP server auth_type not configured"
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"MCP server auth_type not configured",
)
auth_type_str = mcp_server.auth_type.value
@@ -1765,13 +1775,14 @@ def upsert_mcp_server(
),
)
except HTTPException:
# Re-raise HTTP exceptions as-is
except OnyxError:
# Re-raise OnyxError as-is
raise
except Exception as e:
logger.exception("Failed to create/update MCP tool")
raise HTTPException(
status_code=500, detail=f"Failed to create/update MCP tool: {str(e)}"
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Failed to create/update MCP tool: {str(e)}",
)
@@ -1786,7 +1797,7 @@ def update_mcp_server_with_tools(
try:
mcp_server = get_mcp_server_by_id(request.server_id, db_session)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
_ensure_mcp_server_owner_or_admin(mcp_server, user)
@@ -1794,8 +1805,9 @@ def update_mcp_server_with_tools(
MCPAuthenticationType.NONE,
MCPAuthenticationType.PT_OAUTH,
):
raise HTTPException(
status_code=400, detail="MCP server has no admin connection config"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"MCP server has no admin connection config",
)
name_changed = request.name is not None and request.name != mcp_server.name
@@ -1877,7 +1889,7 @@ def update_mcp_server_simple(
try:
mcp_server = get_mcp_server_by_id(server_id, db_session)
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
_ensure_mcp_server_owner_or_admin(mcp_server, user)
@@ -1938,7 +1950,7 @@ def delete_mcp_server_admin(
return {"success": True}
except ValueError:
raise HTTPException(status_code=404, detail="MCP server not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "MCP server not found")
except Exception as e:
logger.error(f"Failed to delete MCP server {server_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to delete MCP server")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to delete MCP server")