convert nodes_morphology.py to V3 schema (#10159)

This commit is contained in:
Alexander Piskun 2025-10-02 23:53:00 +03:00 committed by GitHub
parent 0e9d1724be
commit 8f4ee9984c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,24 +1,34 @@
import torch import torch
import comfy.model_management import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color import kornia.color
class Morphology: class Morphology(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"image": ("IMAGE",), return io.Schema(
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), node_id="Morphology",
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), display_name="ImageMorphology",
}} category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Combo.Input(
"operation",
options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],
),
io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "process" def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
CATEGORY = "image/postprocessing"
def process(self, image, operation, kernel_size):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device) kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1) image_k = image.to(device).movedim(-1, 1)
@ -39,49 +49,63 @@ class Morphology:
else: else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,) return io.NodeOutput(img_out)
class ImageRGBToYUV: class ImageRGBToYUV(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return io.Schema(
}} node_id="ImageRGBToYUV",
category="image/batch",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(display_name="Y"),
io.Image.Output(display_name="U"),
io.Image.Output(display_name="V"),
],
)
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") @classmethod
RETURN_NAMES = ("Y", "U", "V") def execute(cls, image) -> io.NodeOutput:
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB: class ImageYUVToRGB(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"Y": ("IMAGE",), return io.Schema(
"U": ("IMAGE",), node_id="ImageYUVToRGB",
"V": ("IMAGE",), category="image/batch",
}} inputs=[
io.Image.Input("Y"),
io.Image.Input("U"),
io.Image.Input("V"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "execute" def execute(cls, Y, U, V) -> io.NodeOutput:
CATEGORY = "image/batch"
def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,) return io.NodeOutput(out)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}
NODE_DISPLAY_NAME_MAPPINGS = { class MorphologyExtension(ComfyExtension):
"Morphology": "ImageMorphology", @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Morphology,
ImageRGBToYUV,
ImageYUVToRGB,
]
async def comfy_entrypoint() -> MorphologyExtension:
return MorphologyExtension()