[refactor] Use Pydantic models for query parameter validation

- Added query parameter models to OpenAPI spec for GET endpoints
- Regenerated data models to include new query param models
- Replaced manual validation with Pydantic model validation
- Removed obsolete validate_required_params helper function
- Provides better error messages and type safety for API endpoints

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
bymyself 2025-06-17 14:42:25 -07:00
parent f1b3c6b735
commit a4bf6bddbf
4 changed files with 111 additions and 54 deletions

View File

@ -49,6 +49,9 @@ from .generated_models import (
UninstallPackParams,
DisablePackParams,
EnablePackParams,
UpdateAllQueryParams,
UpdateComfyUIQueryParams,
ComfyUISwitchVersionQueryParams,
QueueStatus,
ManagerMappings,
ModelMetadata,
@ -104,6 +107,9 @@ __all__ = [
"UninstallPackParams",
"DisablePackParams",
"EnablePackParams",
"UpdateAllQueryParams",
"UpdateComfyUIQueryParams",
"ComfyUISwitchVersionQueryParams",
"QueueStatus",
"ManagerMappings",
"ModelMetadata",

View File

@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2025-06-17T20:27:16+00:00
# timestamp: 2025-06-17T21:37:15+00:00
from __future__ import annotations
@ -252,6 +252,33 @@ class EnablePackParams(BaseModel):
)
class UpdateAllQueryParams(BaseModel):
client_id: str = Field(
..., description='Client identifier that initiated the request'
)
ui_id: str = Field(..., description='Base UI identifier for task tracking')
mode: Optional[ManagerDatabaseSource] = None
class UpdateComfyUIQueryParams(BaseModel):
client_id: str = Field(
..., description='Client identifier that initiated the request'
)
ui_id: str = Field(..., description='UI identifier for task tracking')
stable: Optional[bool] = Field(
True,
description='Whether to update to stable version (true) or nightly (false)',
)
class ComfyUISwitchVersionQueryParams(BaseModel):
ver: str = Field(..., description='Version to switch to')
client_id: str = Field(
..., description='Client identifier that initiated the request'
)
ui_id: str = Field(..., description='UI identifier for task tracking')
class QueueStatus(BaseModel):
total_count: int = Field(
..., description='Total number of tasks (pending + running)'

View File

@ -75,6 +75,9 @@ from ..data_models import (
OperationResult,
ManagerDatabaseSource,
SecurityLevel,
UpdateAllQueryParams,
UpdateComfyUIQueryParams,
ComfyUISwitchVersionQueryParams,
)
from .constants import (
@ -102,30 +105,6 @@ def is_loopback(address):
return False
def validate_required_params(request: web.Request, required_params: List[str]) -> Optional[web.Response]:
"""Validate that all required query parameters are present.
Args:
request: The aiohttp request object
required_params: List of required parameter names
Returns:
web.Response with 400 status if validation fails, None if validation passes
"""
missing_params = []
for param in required_params:
if param not in request.rel_url.query:
missing_params.append(param)
if missing_params:
missing_str = ", ".join(missing_params)
return web.Response(
status=400,
text=f"Missing required parameter(s): {missing_str}"
)
return None
def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response:
"""Create a standardized error response.
@ -1269,31 +1248,30 @@ async def fetch_updates(request):
@routes.get("/v2/manager/queue/update_all")
async def update_all(request: web.Request) -> web.Response:
# Validate required query parameters
validation_error = validate_required_params(request, ["client_id", "ui_id"])
if validation_error:
return validation_error
json_data = dict(request.rel_url.query)
return await _update_all(json_data)
try:
# Validate query parameters using Pydantic model
query_params = UpdateAllQueryParams.model_validate(dict(request.rel_url.query))
return await _update_all(query_params)
except ValidationError as e:
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
async def _update_all(json_data: Dict[str, Any]) -> web.Response:
async def _update_all(params: UpdateAllQueryParams) -> web.Response:
if not security_utils.is_allowed_security_level("middle"):
logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
return web.Response(status=403)
# Extract client info
base_ui_id = json_data["ui_id"]
client_id = json_data["client_id"]
mode = json_data.get("mode", "remote")
# Extract client info from validated params
base_ui_id = params.ui_id
client_id = params.client_id
mode = params.mode.value if params.mode else ManagerDatabaseSource.remote.value
logging.debug(
"[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s",
client_id, base_ui_id, mode
)
if mode == "local":
if mode == ManagerDatabaseSource.local.value:
channel = "local"
else:
channel = core.get_config()["channel_url"]
@ -1586,14 +1564,20 @@ async def queue_start(request):
@routes.get("/v2/manager/queue/update_comfyui")
async def update_comfyui(request):
"""Queue a ComfyUI update based on the configured update policy."""
# Validate required query parameters
validation_error = validate_required_params(request, ["client_id", "ui_id"])
if validation_error:
return validation_error
is_stable = core.get_config()["update_policy"] != "nightly-comfyui"
client_id = request.rel_url.query["client_id"]
ui_id = request.rel_url.query["ui_id"]
try:
# Validate query parameters using Pydantic model
query_params = UpdateComfyUIQueryParams.model_validate(dict(request.rel_url.query))
# Check if stable parameter was provided, otherwise use config
if query_params.stable is None:
is_stable = core.get_config()["update_policy"] != "nightly-comfyui"
else:
is_stable = query_params.stable
client_id = query_params.client_id
ui_id = query_params.ui_id
except ValidationError as e:
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
# Create update-comfyui task
task = QueueTaskItem(
@ -1625,14 +1609,12 @@ async def comfyui_versions(request):
@routes.get("/v2/comfyui_manager/comfyui_switch_version")
async def comfyui_switch_version(request):
try:
# Validate required query parameters
validation_error = validate_required_params(request, ["ver", "client_id", "ui_id"])
if validation_error:
return validation_error
target_version = request.rel_url.query["ver"]
client_id = request.rel_url.query["client_id"]
ui_id = request.rel_url.query["ui_id"]
# Validate query parameters using Pydantic model
query_params = ComfyUISwitchVersionQueryParams.model_validate(dict(request.rel_url.query))
target_version = query_params.ver
client_id = query_params.client_id
ui_id = query_params.ui_id
# Create update-comfyui task with target version
task = QueueTaskItem(
@ -1644,6 +1626,8 @@ async def comfyui_switch_version(request):
task_queue.put(task)
return web.Response(status=200)
except ValidationError as e:
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
except Exception as e:
logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr)
return web.Response(status=400)

View File

@ -378,6 +378,46 @@ components:
type: string
description: ComfyUI Node Registry ID of the package to enable
required: [cnr_id]
# Query Parameter Models
UpdateAllQueryParams:
type: object
properties:
client_id:
type: string
description: Client identifier that initiated the request
ui_id:
type: string
description: Base UI identifier for task tracking
mode:
$ref: '#/components/schemas/ManagerDatabaseSource'
required: [client_id, ui_id]
UpdateComfyUIQueryParams:
type: object
properties:
client_id:
type: string
description: Client identifier that initiated the request
ui_id:
type: string
description: UI identifier for task tracking
stable:
type: boolean
default: true
description: Whether to update to stable version (true) or nightly (false)
required: [client_id, ui_id]
ComfyUISwitchVersionQueryParams:
type: object
properties:
ver:
type: string
description: Version to switch to
client_id:
type: string
description: Client identifier that initiated the request
ui_id:
type: string
description: UI identifier for task tracking
required: [ver, client_id, ui_id]
# Queue Status Models
QueueStatus:
type: object