mirror of
https://git.datalinker.icu/ltdrdata/ComfyUI-Manager
synced 2025-12-09 22:24:23 +08:00
[refactor] Use Pydantic models for query parameter validation
- Added query parameter models to OpenAPI spec for GET endpoints - Regenerated data models to include new query param models - Replaced manual validation with Pydantic model validation - Removed obsolete validate_required_params helper function - Provides better error messages and type safety for API endpoints Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
f1b3c6b735
commit
a4bf6bddbf
@ -49,6 +49,9 @@ from .generated_models import (
|
||||
UninstallPackParams,
|
||||
DisablePackParams,
|
||||
EnablePackParams,
|
||||
UpdateAllQueryParams,
|
||||
UpdateComfyUIQueryParams,
|
||||
ComfyUISwitchVersionQueryParams,
|
||||
QueueStatus,
|
||||
ManagerMappings,
|
||||
ModelMetadata,
|
||||
@ -104,6 +107,9 @@ __all__ = [
|
||||
"UninstallPackParams",
|
||||
"DisablePackParams",
|
||||
"EnablePackParams",
|
||||
"UpdateAllQueryParams",
|
||||
"UpdateComfyUIQueryParams",
|
||||
"ComfyUISwitchVersionQueryParams",
|
||||
"QueueStatus",
|
||||
"ManagerMappings",
|
||||
"ModelMetadata",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: openapi.yaml
|
||||
# timestamp: 2025-06-17T20:27:16+00:00
|
||||
# timestamp: 2025-06-17T21:37:15+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -252,6 +252,33 @@ class EnablePackParams(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class UpdateAllQueryParams(BaseModel):
|
||||
client_id: str = Field(
|
||||
..., description='Client identifier that initiated the request'
|
||||
)
|
||||
ui_id: str = Field(..., description='Base UI identifier for task tracking')
|
||||
mode: Optional[ManagerDatabaseSource] = None
|
||||
|
||||
|
||||
class UpdateComfyUIQueryParams(BaseModel):
|
||||
client_id: str = Field(
|
||||
..., description='Client identifier that initiated the request'
|
||||
)
|
||||
ui_id: str = Field(..., description='UI identifier for task tracking')
|
||||
stable: Optional[bool] = Field(
|
||||
True,
|
||||
description='Whether to update to stable version (true) or nightly (false)',
|
||||
)
|
||||
|
||||
|
||||
class ComfyUISwitchVersionQueryParams(BaseModel):
|
||||
ver: str = Field(..., description='Version to switch to')
|
||||
client_id: str = Field(
|
||||
..., description='Client identifier that initiated the request'
|
||||
)
|
||||
ui_id: str = Field(..., description='UI identifier for task tracking')
|
||||
|
||||
|
||||
class QueueStatus(BaseModel):
|
||||
total_count: int = Field(
|
||||
..., description='Total number of tasks (pending + running)'
|
||||
|
||||
@ -75,6 +75,9 @@ from ..data_models import (
|
||||
OperationResult,
|
||||
ManagerDatabaseSource,
|
||||
SecurityLevel,
|
||||
UpdateAllQueryParams,
|
||||
UpdateComfyUIQueryParams,
|
||||
ComfyUISwitchVersionQueryParams,
|
||||
)
|
||||
|
||||
from .constants import (
|
||||
@ -102,30 +105,6 @@ def is_loopback(address):
|
||||
return False
|
||||
|
||||
|
||||
def validate_required_params(request: web.Request, required_params: List[str]) -> Optional[web.Response]:
|
||||
"""Validate that all required query parameters are present.
|
||||
|
||||
Args:
|
||||
request: The aiohttp request object
|
||||
required_params: List of required parameter names
|
||||
|
||||
Returns:
|
||||
web.Response with 400 status if validation fails, None if validation passes
|
||||
"""
|
||||
missing_params = []
|
||||
for param in required_params:
|
||||
if param not in request.rel_url.query:
|
||||
missing_params.append(param)
|
||||
|
||||
if missing_params:
|
||||
missing_str = ", ".join(missing_params)
|
||||
return web.Response(
|
||||
status=400,
|
||||
text=f"Missing required parameter(s): {missing_str}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response:
|
||||
"""Create a standardized error response.
|
||||
|
||||
@ -1269,31 +1248,30 @@ async def fetch_updates(request):
|
||||
|
||||
@routes.get("/v2/manager/queue/update_all")
|
||||
async def update_all(request: web.Request) -> web.Response:
|
||||
# Validate required query parameters
|
||||
validation_error = validate_required_params(request, ["client_id", "ui_id"])
|
||||
if validation_error:
|
||||
return validation_error
|
||||
|
||||
json_data = dict(request.rel_url.query)
|
||||
return await _update_all(json_data)
|
||||
try:
|
||||
# Validate query parameters using Pydantic model
|
||||
query_params = UpdateAllQueryParams.model_validate(dict(request.rel_url.query))
|
||||
return await _update_all(query_params)
|
||||
except ValidationError as e:
|
||||
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
|
||||
|
||||
|
||||
async def _update_all(json_data: Dict[str, Any]) -> web.Response:
|
||||
async def _update_all(params: UpdateAllQueryParams) -> web.Response:
|
||||
if not security_utils.is_allowed_security_level("middle"):
|
||||
logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
|
||||
return web.Response(status=403)
|
||||
|
||||
# Extract client info
|
||||
base_ui_id = json_data["ui_id"]
|
||||
client_id = json_data["client_id"]
|
||||
mode = json_data.get("mode", "remote")
|
||||
# Extract client info from validated params
|
||||
base_ui_id = params.ui_id
|
||||
client_id = params.client_id
|
||||
mode = params.mode.value if params.mode else ManagerDatabaseSource.remote.value
|
||||
|
||||
logging.debug(
|
||||
"[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s",
|
||||
client_id, base_ui_id, mode
|
||||
)
|
||||
|
||||
if mode == "local":
|
||||
if mode == ManagerDatabaseSource.local.value:
|
||||
channel = "local"
|
||||
else:
|
||||
channel = core.get_config()["channel_url"]
|
||||
@ -1586,14 +1564,20 @@ async def queue_start(request):
|
||||
@routes.get("/v2/manager/queue/update_comfyui")
|
||||
async def update_comfyui(request):
|
||||
"""Queue a ComfyUI update based on the configured update policy."""
|
||||
# Validate required query parameters
|
||||
validation_error = validate_required_params(request, ["client_id", "ui_id"])
|
||||
if validation_error:
|
||||
return validation_error
|
||||
try:
|
||||
# Validate query parameters using Pydantic model
|
||||
query_params = UpdateComfyUIQueryParams.model_validate(dict(request.rel_url.query))
|
||||
|
||||
# Check if stable parameter was provided, otherwise use config
|
||||
if query_params.stable is None:
|
||||
is_stable = core.get_config()["update_policy"] != "nightly-comfyui"
|
||||
client_id = request.rel_url.query["client_id"]
|
||||
ui_id = request.rel_url.query["ui_id"]
|
||||
else:
|
||||
is_stable = query_params.stable
|
||||
|
||||
client_id = query_params.client_id
|
||||
ui_id = query_params.ui_id
|
||||
except ValidationError as e:
|
||||
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
|
||||
|
||||
# Create update-comfyui task
|
||||
task = QueueTaskItem(
|
||||
@ -1625,14 +1609,12 @@ async def comfyui_versions(request):
|
||||
@routes.get("/v2/comfyui_manager/comfyui_switch_version")
|
||||
async def comfyui_switch_version(request):
|
||||
try:
|
||||
# Validate required query parameters
|
||||
validation_error = validate_required_params(request, ["ver", "client_id", "ui_id"])
|
||||
if validation_error:
|
||||
return validation_error
|
||||
# Validate query parameters using Pydantic model
|
||||
query_params = ComfyUISwitchVersionQueryParams.model_validate(dict(request.rel_url.query))
|
||||
|
||||
target_version = request.rel_url.query["ver"]
|
||||
client_id = request.rel_url.query["client_id"]
|
||||
ui_id = request.rel_url.query["ui_id"]
|
||||
target_version = query_params.ver
|
||||
client_id = query_params.client_id
|
||||
ui_id = query_params.ui_id
|
||||
|
||||
# Create update-comfyui task with target version
|
||||
task = QueueTaskItem(
|
||||
@ -1644,6 +1626,8 @@ async def comfyui_switch_version(request):
|
||||
|
||||
task_queue.put(task)
|
||||
return web.Response(status=200)
|
||||
except ValidationError as e:
|
||||
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
|
||||
except Exception as e:
|
||||
logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr)
|
||||
return web.Response(status=400)
|
||||
|
||||
40
openapi.yaml
40
openapi.yaml
@ -378,6 +378,46 @@ components:
|
||||
type: string
|
||||
description: ComfyUI Node Registry ID of the package to enable
|
||||
required: [cnr_id]
|
||||
# Query Parameter Models
|
||||
UpdateAllQueryParams:
|
||||
type: object
|
||||
properties:
|
||||
client_id:
|
||||
type: string
|
||||
description: Client identifier that initiated the request
|
||||
ui_id:
|
||||
type: string
|
||||
description: Base UI identifier for task tracking
|
||||
mode:
|
||||
$ref: '#/components/schemas/ManagerDatabaseSource'
|
||||
required: [client_id, ui_id]
|
||||
UpdateComfyUIQueryParams:
|
||||
type: object
|
||||
properties:
|
||||
client_id:
|
||||
type: string
|
||||
description: Client identifier that initiated the request
|
||||
ui_id:
|
||||
type: string
|
||||
description: UI identifier for task tracking
|
||||
stable:
|
||||
type: boolean
|
||||
default: true
|
||||
description: Whether to update to stable version (true) or nightly (false)
|
||||
required: [client_id, ui_id]
|
||||
ComfyUISwitchVersionQueryParams:
|
||||
type: object
|
||||
properties:
|
||||
ver:
|
||||
type: string
|
||||
description: Version to switch to
|
||||
client_id:
|
||||
type: string
|
||||
description: Client identifier that initiated the request
|
||||
ui_id:
|
||||
type: string
|
||||
description: UI identifier for task tracking
|
||||
required: [ver, client_id, ui_id]
|
||||
# Queue Status Models
|
||||
QueueStatus:
|
||||
type: object
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user