mirror of
https://git.datalinker.icu/ltdrdata/ComfyUI-Manager
synced 2025-12-09 22:24:23 +08:00
[refactor] Use Pydantic models for query parameter validation
- Added query parameter models to OpenAPI spec for GET endpoints - Regenerated data models to include new query param models - Replaced manual validation with Pydantic model validation - Removed obsolete validate_required_params helper function - Provides better error messages and type safety for API endpoints Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
f1b3c6b735
commit
a4bf6bddbf
@ -49,6 +49,9 @@ from .generated_models import (
|
|||||||
UninstallPackParams,
|
UninstallPackParams,
|
||||||
DisablePackParams,
|
DisablePackParams,
|
||||||
EnablePackParams,
|
EnablePackParams,
|
||||||
|
UpdateAllQueryParams,
|
||||||
|
UpdateComfyUIQueryParams,
|
||||||
|
ComfyUISwitchVersionQueryParams,
|
||||||
QueueStatus,
|
QueueStatus,
|
||||||
ManagerMappings,
|
ManagerMappings,
|
||||||
ModelMetadata,
|
ModelMetadata,
|
||||||
@ -104,6 +107,9 @@ __all__ = [
|
|||||||
"UninstallPackParams",
|
"UninstallPackParams",
|
||||||
"DisablePackParams",
|
"DisablePackParams",
|
||||||
"EnablePackParams",
|
"EnablePackParams",
|
||||||
|
"UpdateAllQueryParams",
|
||||||
|
"UpdateComfyUIQueryParams",
|
||||||
|
"ComfyUISwitchVersionQueryParams",
|
||||||
"QueueStatus",
|
"QueueStatus",
|
||||||
"ManagerMappings",
|
"ManagerMappings",
|
||||||
"ModelMetadata",
|
"ModelMetadata",
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# generated by datamodel-codegen:
|
# generated by datamodel-codegen:
|
||||||
# filename: openapi.yaml
|
# filename: openapi.yaml
|
||||||
# timestamp: 2025-06-17T20:27:16+00:00
|
# timestamp: 2025-06-17T21:37:15+00:00
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -252,6 +252,33 @@ class EnablePackParams(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAllQueryParams(BaseModel):
|
||||||
|
client_id: str = Field(
|
||||||
|
..., description='Client identifier that initiated the request'
|
||||||
|
)
|
||||||
|
ui_id: str = Field(..., description='Base UI identifier for task tracking')
|
||||||
|
mode: Optional[ManagerDatabaseSource] = None
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateComfyUIQueryParams(BaseModel):
|
||||||
|
client_id: str = Field(
|
||||||
|
..., description='Client identifier that initiated the request'
|
||||||
|
)
|
||||||
|
ui_id: str = Field(..., description='UI identifier for task tracking')
|
||||||
|
stable: Optional[bool] = Field(
|
||||||
|
True,
|
||||||
|
description='Whether to update to stable version (true) or nightly (false)',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUISwitchVersionQueryParams(BaseModel):
|
||||||
|
ver: str = Field(..., description='Version to switch to')
|
||||||
|
client_id: str = Field(
|
||||||
|
..., description='Client identifier that initiated the request'
|
||||||
|
)
|
||||||
|
ui_id: str = Field(..., description='UI identifier for task tracking')
|
||||||
|
|
||||||
|
|
||||||
class QueueStatus(BaseModel):
|
class QueueStatus(BaseModel):
|
||||||
total_count: int = Field(
|
total_count: int = Field(
|
||||||
..., description='Total number of tasks (pending + running)'
|
..., description='Total number of tasks (pending + running)'
|
||||||
|
|||||||
@ -75,6 +75,9 @@ from ..data_models import (
|
|||||||
OperationResult,
|
OperationResult,
|
||||||
ManagerDatabaseSource,
|
ManagerDatabaseSource,
|
||||||
SecurityLevel,
|
SecurityLevel,
|
||||||
|
UpdateAllQueryParams,
|
||||||
|
UpdateComfyUIQueryParams,
|
||||||
|
ComfyUISwitchVersionQueryParams,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .constants import (
|
from .constants import (
|
||||||
@ -102,30 +105,6 @@ def is_loopback(address):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def validate_required_params(request: web.Request, required_params: List[str]) -> Optional[web.Response]:
|
|
||||||
"""Validate that all required query parameters are present.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The aiohttp request object
|
|
||||||
required_params: List of required parameter names
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
web.Response with 400 status if validation fails, None if validation passes
|
|
||||||
"""
|
|
||||||
missing_params = []
|
|
||||||
for param in required_params:
|
|
||||||
if param not in request.rel_url.query:
|
|
||||||
missing_params.append(param)
|
|
||||||
|
|
||||||
if missing_params:
|
|
||||||
missing_str = ", ".join(missing_params)
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
text=f"Missing required parameter(s): {missing_str}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response:
|
def error_response(status: int, message: str, error_type: Optional[str] = None) -> web.Response:
|
||||||
"""Create a standardized error response.
|
"""Create a standardized error response.
|
||||||
|
|
||||||
@ -1269,31 +1248,30 @@ async def fetch_updates(request):
|
|||||||
|
|
||||||
@routes.get("/v2/manager/queue/update_all")
|
@routes.get("/v2/manager/queue/update_all")
|
||||||
async def update_all(request: web.Request) -> web.Response:
|
async def update_all(request: web.Request) -> web.Response:
|
||||||
# Validate required query parameters
|
try:
|
||||||
validation_error = validate_required_params(request, ["client_id", "ui_id"])
|
# Validate query parameters using Pydantic model
|
||||||
if validation_error:
|
query_params = UpdateAllQueryParams.model_validate(dict(request.rel_url.query))
|
||||||
return validation_error
|
return await _update_all(query_params)
|
||||||
|
except ValidationError as e:
|
||||||
json_data = dict(request.rel_url.query)
|
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
|
||||||
return await _update_all(json_data)
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_all(json_data: Dict[str, Any]) -> web.Response:
|
async def _update_all(params: UpdateAllQueryParams) -> web.Response:
|
||||||
if not security_utils.is_allowed_security_level("middle"):
|
if not security_utils.is_allowed_security_level("middle"):
|
||||||
logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
|
logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
|
||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
|
|
||||||
# Extract client info
|
# Extract client info from validated params
|
||||||
base_ui_id = json_data["ui_id"]
|
base_ui_id = params.ui_id
|
||||||
client_id = json_data["client_id"]
|
client_id = params.client_id
|
||||||
mode = json_data.get("mode", "remote")
|
mode = params.mode.value if params.mode else ManagerDatabaseSource.remote.value
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s",
|
"[ComfyUI-Manager] Update all requested: client_id=%s, base_ui_id=%s, mode=%s",
|
||||||
client_id, base_ui_id, mode
|
client_id, base_ui_id, mode
|
||||||
)
|
)
|
||||||
|
|
||||||
if mode == "local":
|
if mode == ManagerDatabaseSource.local.value:
|
||||||
channel = "local"
|
channel = "local"
|
||||||
else:
|
else:
|
||||||
channel = core.get_config()["channel_url"]
|
channel = core.get_config()["channel_url"]
|
||||||
@ -1586,14 +1564,20 @@ async def queue_start(request):
|
|||||||
@routes.get("/v2/manager/queue/update_comfyui")
|
@routes.get("/v2/manager/queue/update_comfyui")
|
||||||
async def update_comfyui(request):
|
async def update_comfyui(request):
|
||||||
"""Queue a ComfyUI update based on the configured update policy."""
|
"""Queue a ComfyUI update based on the configured update policy."""
|
||||||
# Validate required query parameters
|
try:
|
||||||
validation_error = validate_required_params(request, ["client_id", "ui_id"])
|
# Validate query parameters using Pydantic model
|
||||||
if validation_error:
|
query_params = UpdateComfyUIQueryParams.model_validate(dict(request.rel_url.query))
|
||||||
return validation_error
|
|
||||||
|
# Check if stable parameter was provided, otherwise use config
|
||||||
is_stable = core.get_config()["update_policy"] != "nightly-comfyui"
|
if query_params.stable is None:
|
||||||
client_id = request.rel_url.query["client_id"]
|
is_stable = core.get_config()["update_policy"] != "nightly-comfyui"
|
||||||
ui_id = request.rel_url.query["ui_id"]
|
else:
|
||||||
|
is_stable = query_params.stable
|
||||||
|
|
||||||
|
client_id = query_params.client_id
|
||||||
|
ui_id = query_params.ui_id
|
||||||
|
except ValidationError as e:
|
||||||
|
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
|
||||||
|
|
||||||
# Create update-comfyui task
|
# Create update-comfyui task
|
||||||
task = QueueTaskItem(
|
task = QueueTaskItem(
|
||||||
@ -1625,14 +1609,12 @@ async def comfyui_versions(request):
|
|||||||
@routes.get("/v2/comfyui_manager/comfyui_switch_version")
|
@routes.get("/v2/comfyui_manager/comfyui_switch_version")
|
||||||
async def comfyui_switch_version(request):
|
async def comfyui_switch_version(request):
|
||||||
try:
|
try:
|
||||||
# Validate required query parameters
|
# Validate query parameters using Pydantic model
|
||||||
validation_error = validate_required_params(request, ["ver", "client_id", "ui_id"])
|
query_params = ComfyUISwitchVersionQueryParams.model_validate(dict(request.rel_url.query))
|
||||||
if validation_error:
|
|
||||||
return validation_error
|
target_version = query_params.ver
|
||||||
|
client_id = query_params.client_id
|
||||||
target_version = request.rel_url.query["ver"]
|
ui_id = query_params.ui_id
|
||||||
client_id = request.rel_url.query["client_id"]
|
|
||||||
ui_id = request.rel_url.query["ui_id"]
|
|
||||||
|
|
||||||
# Create update-comfyui task with target version
|
# Create update-comfyui task with target version
|
||||||
task = QueueTaskItem(
|
task = QueueTaskItem(
|
||||||
@ -1644,6 +1626,8 @@ async def comfyui_switch_version(request):
|
|||||||
|
|
||||||
task_queue.put(task)
|
task_queue.put(task)
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
except ValidationError as e:
|
||||||
|
return web.json_response({"error": "Validation error", "details": e.errors()}, status=400)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr)
|
logging.error(f"ComfyUI version switch fail: {e}", file=sys.stderr)
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|||||||
40
openapi.yaml
40
openapi.yaml
@ -378,6 +378,46 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
description: ComfyUI Node Registry ID of the package to enable
|
description: ComfyUI Node Registry ID of the package to enable
|
||||||
required: [cnr_id]
|
required: [cnr_id]
|
||||||
|
# Query Parameter Models
|
||||||
|
UpdateAllQueryParams:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
client_id:
|
||||||
|
type: string
|
||||||
|
description: Client identifier that initiated the request
|
||||||
|
ui_id:
|
||||||
|
type: string
|
||||||
|
description: Base UI identifier for task tracking
|
||||||
|
mode:
|
||||||
|
$ref: '#/components/schemas/ManagerDatabaseSource'
|
||||||
|
required: [client_id, ui_id]
|
||||||
|
UpdateComfyUIQueryParams:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
client_id:
|
||||||
|
type: string
|
||||||
|
description: Client identifier that initiated the request
|
||||||
|
ui_id:
|
||||||
|
type: string
|
||||||
|
description: UI identifier for task tracking
|
||||||
|
stable:
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
description: Whether to update to stable version (true) or nightly (false)
|
||||||
|
required: [client_id, ui_id]
|
||||||
|
ComfyUISwitchVersionQueryParams:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
ver:
|
||||||
|
type: string
|
||||||
|
description: Version to switch to
|
||||||
|
client_id:
|
||||||
|
type: string
|
||||||
|
description: Client identifier that initiated the request
|
||||||
|
ui_id:
|
||||||
|
type: string
|
||||||
|
description: UI identifier for task tracking
|
||||||
|
required: [ver, client_id, ui_id]
|
||||||
# Queue Status Models
|
# Queue Status Models
|
||||||
QueueStatus:
|
QueueStatus:
|
||||||
type: object
|
type: object
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user