[feat] Add comprehensive Pydantic validation to all API endpoints

- Updated all POST endpoints to use proper Pydantic model validation:
  - `/v2/manager/queue/task` - validates QueueTaskItem
  - `/v2/manager/queue/install_model` - validates ModelMetadata
  - `/v2/manager/queue/reinstall` - validates InstallPackParams
  - `/v2/customnode/import_fail_info` - validates cnr_id/url fields

- Added proper error handling with ValidationError for detailed error messages
- Updated TaskQueue.put() to handle both dict and Pydantic model inputs
- Added missing imports: InstallPackParams, ModelMetadata, ValidationError

Benefits:
- Early validation catches invalid data at API boundaries
- Better error messages for clients with specific validation failures
- Type safety throughout the request processing pipeline
- Consistent validation behavior across all endpoints

All ruff checks pass and validation is now enabled by default.
This commit is contained in:
bymyself 2025-06-08 01:50:36 -07:00
parent 7f1ebbe081
commit 884b503728

View File

@ -16,6 +16,7 @@ from datetime import datetime
import heapq
import copy
from typing import NamedTuple, List, Literal, Optional
from pydantic import ValidationError
from comfy.cli_args import args
import latent_preview
from aiohttp import web
@ -57,6 +58,8 @@ from ..data_models import (
InstalledNodeInfo,
InstalledModelInfo,
ComfyUIVersionInfo,
InstallPackParams,
ModelMetadata,
)
from .constants import (
@ -208,12 +211,17 @@ class TaskQueue:
"""
PromptServer.instance.send_sync(msg, update.model_dump(), client_id)
def put(self, item: QueueTaskItem) -> None:
def put(self, item) -> None:
"""Add a task to the queue. Item can be a dict or QueueTaskItem model."""
with self.mutex:
# Start a new batch if this is the first task after queue was empty
if self.batch_id is None and len(self.pending_tasks) == 0 and len(self.running_tasks) == 0:
self._start_new_batch()
# Convert to dict if it's a Pydantic model
if hasattr(item, 'model_dump'):
item = item.model_dump()
heapq.heappush(self.pending_tasks, item)
self.not_empty.notify()
@ -911,12 +919,21 @@ async def queue_task(request) -> web.Response:
request: aiohttp request containing JSON task data
Returns:
web.Response: HTTP 200 on successful queueing
web.Response: HTTP 200 on successful queueing, HTTP 400 on validation error
"""
json_data = await request.json()
TaskQueue.instance.put(json_data)
# maybe start worker
return web.Response(status=200)
try:
json_data = await request.json()
# Validate input using Pydantic model
task_item = QueueTaskItem.model_validate(json_data)
TaskQueue.instance.put(task_item)
# maybe start worker
return web.Response(status=200)
except ValidationError as e:
logging.error(f"[ComfyUI-Manager] Invalid task data: {e}")
return web.Response(status=400, text=f"Invalid task data: {e}")
except Exception as e:
logging.error(f"[ComfyUI-Manager] Error processing task: {e}")
return web.Response(status=500, text="Internal server error")
@routes.get("/v2/manager/queue/history_list")
@ -1365,25 +1382,52 @@ def unzip_install(files):
@routes.post("/v2/customnode/import_fail_info")
async def import_fail_info(request):
json_data = await request.json()
try:
json_data = await request.json()
# Basic validation - ensure we have either cnr_id or url
if not isinstance(json_data, dict):
return web.Response(status=400, text="Request body must be a JSON object")
if "cnr_id" not in json_data and "url" not in json_data:
return web.Response(status=400, text="Either 'cnr_id' or 'url' field is required")
if "cnr_id" in json_data:
module_name = core.unified_manager.get_module_name(json_data["cnr_id"])
else:
module_name = core.unified_manager.get_module_name(json_data["url"])
if "cnr_id" in json_data:
if not isinstance(json_data["cnr_id"], str):
return web.Response(status=400, text="'cnr_id' must be a string")
module_name = core.unified_manager.get_module_name(json_data["cnr_id"])
else:
if not isinstance(json_data["url"], str):
return web.Response(status=400, text="'url' must be a string")
module_name = core.unified_manager.get_module_name(json_data["url"])
if module_name is not None:
info = cm_global.error_dict.get(module_name)
if info is not None:
return web.json_response(info)
if module_name is not None:
info = cm_global.error_dict.get(module_name)
if info is not None:
return web.json_response(info)
return web.Response(status=400)
return web.Response(status=400)
except Exception as e:
logging.error(f"[ComfyUI-Manager] Error processing import fail info: {e}")
return web.Response(status=500, text="Internal server error")
@routes.post("/v2/manager/queue/reinstall")
async def reinstall_custom_node(request):
await _uninstall_custom_node(await request.json())
await _install_custom_node(await request.json())
try:
json_data = await request.json()
# Validate input using Pydantic model
pack_data = InstallPackParams.model_validate(json_data)
validated_data = pack_data.model_dump()
await _uninstall_custom_node(validated_data)
await _install_custom_node(validated_data)
return web.Response(status=200)
except ValidationError as e:
logging.error(f"[ComfyUI-Manager] Invalid pack data: {e}")
return web.Response(status=400, text=f"Invalid pack data: {e}")
except Exception as e:
logging.error(f"[ComfyUI-Manager] Error processing reinstall: {e}")
return web.Response(status=500, text="Internal server error")
@routes.get("/v2/manager/queue/reset")
@ -1735,8 +1779,17 @@ async def check_whitelist_for_model(item):
@routes.post("/v2/manager/queue/install_model")
async def install_model(request):
json_data = await request.json()
return await _install_model(json_data)
try:
json_data = await request.json()
# Validate input using Pydantic model
model_data = ModelMetadata.model_validate(json_data)
return await _install_model(model_data.model_dump())
except ValidationError as e:
logging.error(f"[ComfyUI-Manager] Invalid model data: {e}")
return web.Response(status=400, text=f"Invalid model data: {e}")
except Exception as e:
logging.error(f"[ComfyUI-Manager] Error processing model install: {e}")
return web.Response(status=500, text="Internal server error")
async def _install_model(json_data):