diff --git a/comfyui_manager/data_models/__init__.py b/comfyui_manager/data_models/__init__.py index be99fd36..5d115e6a 100644 --- a/comfyui_manager/data_models/__init__.py +++ b/comfyui_manager/data_models/__init__.py @@ -49,6 +49,9 @@ from .generated_models import ( UninstallPackParams, DisablePackParams, EnablePackParams, + UpdateAllQueryParams, + UpdateComfyUIQueryParams, + ComfyUISwitchVersionQueryParams, QueueStatus, ManagerMappings, ModelMetadata, @@ -104,6 +107,9 @@ __all__ = [ "UninstallPackParams", "DisablePackParams", "EnablePackParams", + "UpdateAllQueryParams", + "UpdateComfyUIQueryParams", + "ComfyUISwitchVersionQueryParams", "QueueStatus", "ManagerMappings", "ModelMetadata", diff --git a/comfyui_manager/data_models/generated_models.py b/comfyui_manager/data_models/generated_models.py index 76e2bf13..6e3349f1 100644 --- a/comfyui_manager/data_models/generated_models.py +++ b/comfyui_manager/data_models/generated_models.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2025-06-17T20:27:16+00:00 +# timestamp: 2025-06-17T21:37:15+00:00 from __future__ import annotations @@ -252,6 +252,33 @@ class EnablePackParams(BaseModel): ) +class UpdateAllQueryParams(BaseModel): + client_id: str = Field( + ..., description='Client identifier that initiated the request' + ) + ui_id: str = Field(..., description='Base UI identifier for task tracking') + mode: Optional[ManagerDatabaseSource] = None + + +class UpdateComfyUIQueryParams(BaseModel): + client_id: str = Field( + ..., description='Client identifier that initiated the request' + ) + ui_id: str = Field(..., description='UI identifier for task tracking') + stable: Optional[bool] = Field( + True, + description='Whether to update to stable version (true) or nightly (false)', + ) + + +class ComfyUISwitchVersionQueryParams(BaseModel): + ver: str = Field(..., description='Version to switch to') + client_id: str = Field( + ..., description='Client identifier that initiated the request' + ) + ui_id: str = Field(..., description='UI identifier for task tracking') + + class QueueStatus(BaseModel): total_count: int = Field( ..., description='Total number of tasks (pending + running)' diff --git a/comfyui_manager/glob/manager_server.py b/comfyui_manager/glob/manager_server.py index b0842273..84f23920 100644 --- a/comfyui_manager/glob/manager_server.py +++ b/comfyui_manager/glob/manager_server.py @@ -75,6 +75,9 @@ from ..data_models import ( OperationResult, ManagerDatabaseSource, SecurityLevel, + UpdateAllQueryParams, + UpdateComfyUIQueryParams, + ComfyUISwitchVersionQueryParams, ) from .constants import ( @@ -102,30 +105,6 @@ def is_loopback(address): return False -def validate_required_params(request: web.Request, required_params: List[str]) -> Optional[web.Response]: - """Validate that all required query parameters are present. - - Args: - request: The aiohttp request object - required_params: List of required parameter names - - Returns: - web.Response with 400 status if validation fails, None if validation passes - """ - missing_params = [] - for param in required_params: - if param not in request.rel_url.query: - missing_params.append(param) - - if missing_params: - missing_str = ", ".join(missing_params) - return web.Response( - status=400, - text=f"Missing required parameter(s): {missing_str}" - ) - return None - - def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response: """Create a standardized error response. @@ -1269,31 +1248,30 @@ async def fetch_updates(request): @routes.get("/v2/manager/queue/update_all") async def update_all(request: web.Request) -> web.Response: - # Validate required query parameters - validation_error = validate_required_params(request, ["client_id", "ui_id"]) - if validation_error: - return validation_error - - json_data = dict(request.rel_url.query) - return await _update_all(json_data) + try: + # Validate query parameters using Pydantic model + query_params = UpdateAllQueryParams.model_validate(dict(request.rel_url.query)) + return await _update_all(query_params) + except ValidationError as e: + return web.json_response({"error": "Validation error", "details": e.errors()}, status=400) -async def _update_all(json_data: Dict[str, Any]) -> web.Response: +async def _update_all(params: UpdateAllQueryParams) -> web.Response: if not security_utils.is_allowed_security_level("middle"): logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW) return web.Response(status=403) - # Extract client info - base_ui_id = json_data["ui_id"] - client_id = json_data["client_id"] - mode = json_data.get("mode", "remote") + # Extract client info from validated params + base_ui_id = params.ui_id + client_id = params.client_id + mode = params.mode.value if params.mode else ManagerDatabaseSource.remote.value logging.debug( "[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s", client_id, base_ui_id, mode ) - if mode == "local": + if mode == ManagerDatabaseSource.local.value: channel = "local" else: channel = core.get_config()["channel_url"] @@ -1586,14 +1564,20 @@ async def queue_start(request): @routes.get("/v2/manager/queue/update_comfyui") async def update_comfyui(request): """Queue a ComfyUI update based on the configured update policy.""" - # Validate required query parameters - validation_error = validate_required_params(request, ["client_id", "ui_id"]) - if validation_error: - return validation_error - - is_stable = core.get_config()["update_policy"] != "nightly-comfyui" - client_id = request.rel_url.query["client_id"] - ui_id = request.rel_url.query["ui_id"] + try: + # Validate query parameters using Pydantic model + query_params = UpdateComfyUIQueryParams.model_validate(dict(request.rel_url.query)) + + # Check if stable parameter was provided, otherwise use config + if query_params.stable is None: + is_stable = core.get_config()["update_policy"] != "nightly-comfyui" + else: + is_stable = query_params.stable + + client_id = query_params.client_id + ui_id = query_params.ui_id + except ValidationError as e: + return web.json_response({"error": "Validation error", "details": e.errors()}, status=400) # Create update-comfyui task task = QueueTaskItem( @@ -1625,14 +1609,12 @@ async def comfyui_versions(request): @routes.get("/v2/comfyui_manager/comfyui_switch_version") async def comfyui_switch_version(request): try: - # Validate required query parameters - validation_error = validate_required_params(request, ["ver", "client_id", "ui_id"]) - if validation_error: - return validation_error - - target_version = request.rel_url.query["ver"] - client_id = request.rel_url.query["client_id"] - ui_id = request.rel_url.query["ui_id"] + # Validate query parameters using Pydantic model + query_params = ComfyUISwitchVersionQueryParams.model_validate(dict(request.rel_url.query)) + + target_version = query_params.ver + client_id = query_params.client_id + ui_id = query_params.ui_id # Create update-comfyui task with target version task = QueueTaskItem( @@ -1644,6 +1626,8 @@ async def comfyui_switch_version(request): task_queue.put(task) return web.Response(status=200) + except ValidationError as e: + return web.json_response({"error": "Validation error", "details": e.errors()}, status=400) except Exception as e: logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr) return web.Response(status=400) diff --git a/openapi.yaml b/openapi.yaml index 991d29af..37a2ba4c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -378,6 +378,46 @@ components: type: string description: ComfyUI Node Registry ID of the package to enable required: [cnr_id] + # Query Parameter Models + UpdateAllQueryParams: + type: object + properties: + client_id: + type: string + description: Client identifier that initiated the request + ui_id: + type: string + description: Base UI identifier for task tracking + mode: + $ref: '#/components/schemas/ManagerDatabaseSource' + required: [client_id, ui_id] + UpdateComfyUIQueryParams: + type: object + properties: + client_id: + type: string + description: Client identifier that initiated the request + ui_id: + type: string + description: UI identifier for task tracking + stable: + type: boolean + default: true + description: Whether to update to stable version (true) or nightly (false) + required: [client_id, ui_id] + ComfyUISwitchVersionQueryParams: + type: object + properties: + ver: + type: string + description: Version to switch to + client_id: + type: string + description: Client identifier that initiated the request + ui_id: + type: string + description: UI identifier for task tracking + required: [ver, client_id, ui_id] # Queue Status Models QueueStatus: type: object