[refactor] Use Pydantic models for query parameter validation

- Added query parameter models to OpenAPI spec for GET endpoints
- Regenerated data models to include new query param models
- Replaced manual validation with Pydantic model validation
- Removed obsolete validate_required_params helper function
- Provides better error messages and type safety for API endpoints

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
bymyself 2025-06-17 14:42:25 -07:00
parent f1b3c6b735
commit a4bf6bddbf
4 changed files with 111 additions and 54 deletions

View File

@ -49,6 +49,9 @@ from .generated_models import (
UninstallPackParams, UninstallPackParams,
DisablePackParams, DisablePackParams,
EnablePackParams, EnablePackParams,
UpdateAllQueryParams,
UpdateComfyUIQueryParams,
ComfyUISwitchVersionQueryParams,
QueueStatus, QueueStatus,
ManagerMappings, ManagerMappings,
ModelMetadata, ModelMetadata,
@ -104,6 +107,9 @@ __all__ = [
"UninstallPackParams", "UninstallPackParams",
"DisablePackParams", "DisablePackParams",
"EnablePackParams", "EnablePackParams",
"UpdateAllQueryParams",
"UpdateComfyUIQueryParams",
"ComfyUISwitchVersionQueryParams",
"QueueStatus", "QueueStatus",
"ManagerMappings", "ManagerMappings",
"ModelMetadata", "ModelMetadata",

View File

@ -1,6 +1,6 @@
# generated by datamodel-codegen: # generated by datamodel-codegen:
# filename: openapi.yaml # filename: openapi.yaml
# timestamp: 2025-06-17T20:27:16+00:00 # timestamp: 2025-06-17T21:37:15+00:00
from __future__ import annotations from __future__ import annotations
@ -252,6 +252,33 @@ class EnablePackParams(BaseModel):
) )
class UpdateAllQueryParams(BaseModel):
client_id: str = Field(
..., description='Client identifier that initiated the request'
)
ui_id: str = Field(..., description='Base UI identifier for task tracking')
mode: Optional[ManagerDatabaseSource] = None
class UpdateComfyUIQueryParams(BaseModel):
client_id: str = Field(
..., description='Client identifier that initiated the request'
)
ui_id: str = Field(..., description='UI identifier for task tracking')
stable: Optional[bool] = Field(
True,
description='Whether to update to stable version (true) or nightly (false)',
)
class ComfyUISwitchVersionQueryParams(BaseModel):
ver: str = Field(..., description='Version to switch to')
client_id: str = Field(
..., description='Client identifier that initiated the request'
)
ui_id: str = Field(..., description='UI identifier for task tracking')
class QueueStatus(BaseModel): class QueueStatus(BaseModel):
total_count: int = Field( total_count: int = Field(
..., description='Total number of tasks (pending + running)' ..., description='Total number of tasks (pending + running)'

View File

@ -75,6 +75,9 @@ from ..data_models import (
OperationResult, OperationResult,
ManagerDatabaseSource, ManagerDatabaseSource,
SecurityLevel, SecurityLevel,
UpdateAllQueryParams,
UpdateComfyUIQueryParams,
ComfyUISwitchVersionQueryParams,
) )
from .constants import ( from .constants import (
@ -102,30 +105,6 @@ def is_loopback(address):
return False return False
def validate_required_params(request: web.Request, required_params: List[str]) -> Optional[web.Response]:
"""Validate that all required query parameters are present.
Args:
request: The aiohttp request object
required_params: List of required parameter names
Returns:
web.Response with 400 status if validation fails, None if validation passes
"""
missing_params = []
for param in required_params:
if param not in request.rel_url.query:
missing_params.append(param)
if missing_params:
missing_str = ", ".join(missing_params)
return web.Response(
status=400,
text=f"Missing required parameter(s): {missing_str}"
)
return None
def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response: def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response:
"""Create a standardized error response. """Create a standardized error response.
@ -1269,31 +1248,30 @@ async def fetch_updates(request):
@routes.get("/v2/manager/queue/update_all") @routes.get("/v2/manager/queue/update_all")
async def update_all(request: web.Request) -> web.Response: async def update_all(request: web.Request) -> web.Response:
# Validate required query parameters try:
validation_error = validate_required_params(request, ["client_id", "ui_id"]) # Validate query parameters using Pydantic model
if validation_error: query_params = UpdateAllQueryParams.model_validate(dict(request.rel_url.query))
return validation_error return await _update_all(query_params)
except ValidationError as e:
json_data = dict(request.rel_url.query) return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
return await _update_all(json_data)
async def _update_all(json_data: Dict[str, Any]) -> web.Response: async def _update_all(params: UpdateAllQueryParams) -> web.Response:
if not security_utils.is_allowed_security_level("middle"): if not security_utils.is_allowed_security_level("middle"):
logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW) logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
return web.Response(status=403) return web.Response(status=403)
# Extract client info # Extract client info from validated params
base_ui_id = json_data["ui_id"] base_ui_id = params.ui_id
client_id = json_data["client_id"] client_id = params.client_id
mode = json_data.get("mode", "remote") mode = params.mode.value if params.mode else ManagerDatabaseSource.remote.value
logging.debug( logging.debug(
"[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s", "[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s",
client_id, base_ui_id, mode client_id, base_ui_id, mode
) )
if mode == "local": if mode == ManagerDatabaseSource.local.value:
channel = "local" channel = "local"
else: else:
channel = core.get_config()["channel_url"] channel = core.get_config()["channel_url"]
@ -1586,14 +1564,20 @@ async def queue_start(request):
@routes.get("/v2/manager/queue/update_comfyui") @routes.get("/v2/manager/queue/update_comfyui")
async def update_comfyui(request): async def update_comfyui(request):
"""Queue a ComfyUI update based on the configured update policy.""" """Queue a ComfyUI update based on the configured update policy."""
# Validate required query parameters try:
validation_error = validate_required_params(request, ["client_id", "ui_id"]) # Validate query parameters using Pydantic model
if validation_error: query_params = UpdateComfyUIQueryParams.model_validate(dict(request.rel_url.query))
return validation_error
# Check if stable parameter was provided, otherwise use config
if query_params.stable is None:
is_stable = core.get_config()["update_policy"] != "nightly-comfyui" is_stable = core.get_config()["update_policy"] != "nightly-comfyui"
client_id = request.rel_url.query["client_id"] else:
ui_id = request.rel_url.query["ui_id"] is_stable = query_params.stable
client_id = query_params.client_id
ui_id = query_params.ui_id
except ValidationError as e:
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
# Create update-comfyui task # Create update-comfyui task
task = QueueTaskItem( task = QueueTaskItem(
@ -1625,14 +1609,12 @@ async def comfyui_versions(request):
@routes.get("/v2/comfyui_manager/comfyui_switch_version") @routes.get("/v2/comfyui_manager/comfyui_switch_version")
async def comfyui_switch_version(request): async def comfyui_switch_version(request):
try: try:
# Validate required query parameters # Validate query parameters using Pydantic model
validation_error = validate_required_params(request, ["ver", "client_id", "ui_id"]) query_params = ComfyUISwitchVersionQueryParams.model_validate(dict(request.rel_url.query))
if validation_error:
return validation_error
target_version = request.rel_url.query["ver"] target_version = query_params.ver
client_id = request.rel_url.query["client_id"] client_id = query_params.client_id
ui_id = request.rel_url.query["ui_id"] ui_id = query_params.ui_id
# Create update-comfyui task with target version # Create update-comfyui task with target version
task = QueueTaskItem( task = QueueTaskItem(
@ -1644,6 +1626,8 @@ async def comfyui_switch_version(request):
task_queue.put(task) task_queue.put(task)
return web.Response(status=200) return web.Response(status=200)
except ValidationError as e:
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
except Exception as e: except Exception as e:
logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr) logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr)
return web.Response(status=400) return web.Response(status=400)

View File

@ -378,6 +378,46 @@ components:
type: string type: string
description: ComfyUI Node Registry ID of the package to enable description: ComfyUI Node Registry ID of the package to enable
required: [cnr_id] required: [cnr_id]
# Query Parameter Models
UpdateAllQueryParams:
type: object
properties:
client_id:
type: string
description: Client identifier that initiated the request
ui_id:
type: string
description: Base UI identifier for task tracking
mode:
$ref: '#/components/schemas/ManagerDatabaseSource'
required: [client_id, ui_id]
UpdateComfyUIQueryParams:
type: object
properties:
client_id:
type: string
description: Client identifier that initiated the request
ui_id:
type: string
description: UI identifier for task tracking
stable:
type: boolean
default: true
description: Whether to update to stable version (true) or nightly (false)
required: [client_id, ui_id]
ComfyUISwitchVersionQueryParams:
type: object
properties:
ver:
type: string
description: Version to switch to
client_id:
type: string
description: Client identifier that initiated the request
ui_id:
type: string
description: UI identifier for task tracking
required: [ver, client_id, ui_id]
# Queue Status Models # Queue Status Models
QueueStatus: QueueStatus:
type: object type: object