convert nodes_compositing.py to V3 schema (#10174)

Alexander Piskun 2025-10-09 09:13:15 +03:00 committed by GitHub
parent 989f715d92
commit 6732014a0a

nodes_compositing.py

@@ -1,6 +1,9 @@
 import torch
 import comfy.utils
 from enum import Enum
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

 def resize_mask(mask, shape):
     return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
     return out_image, out_alpha

-class PorterDuffImageComposite:
+class PorterDuffImageComposite(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "source": ("IMAGE",),
-                "source_alpha": ("MASK",),
-                "destination": ("IMAGE",),
-                "destination_alpha": ("MASK",),
-                "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE", "MASK")
-    FUNCTION = "composite"
-    CATEGORY = "mask/compositing"
-
-    def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="PorterDuffImageComposite",
+            display_name="Porter-Duff Image Composite",
+            category="mask/compositing",
+            inputs=[
+                io.Image.Input("source"),
+                io.Mask.Input("source_alpha"),
+                io.Image.Input("destination"),
+                io.Mask.Input("destination_alpha"),
+                io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
+            ],
+            outputs=[
+                io.Image.Output(),
+                io.Mask.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
         batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
         out_images = []
         out_alphas = []
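Note on the Combo input: its options are the enum member names (mode.name), so the selected mode reaches execute as a plain string. The unchanged loop body elided between these hunks presumably still resolves that name back to the enum, along the lines of this sketch:

# sketch of the name-to-enum lookup assumed to happen in the elided context lines
blend_mode = PorterDuffMode[mode]   # e.g. "DST" -> PorterDuffMode.DST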
@@ -150,45 +157,48 @@ class PorterDuffImageComposite:
             out_images.append(out_image)
             out_alphas.append(out_alpha.squeeze(2))

-        result = (torch.stack(out_images), torch.stack(out_alphas))
-        return result
+        return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))

-class SplitImageWithAlpha:
+class SplitImageWithAlpha(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-            }
-        }
-
-    CATEGORY = "mask/compositing"
-    RETURN_TYPES = ("IMAGE", "MASK")
-    FUNCTION = "split_image_with_alpha"
-
-    def split_image_with_alpha(self, image: torch.Tensor):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SplitImageWithAlpha",
+            display_name="Split Image with Alpha",
+            category="mask/compositing",
+            inputs=[
+                io.Image.Input("image"),
+            ],
+            outputs=[
+                io.Image.Output(),
+                io.Mask.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image: torch.Tensor) -> io.NodeOutput:
         out_images = [i[:,:,:3] for i in image]
         out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
-        result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
-        return result
+        return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))

-class JoinImageWithAlpha:
+class JoinImageWithAlpha(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-                "alpha": ("MASK",),
-            }
-        }
-
-    CATEGORY = "mask/compositing"
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "join_image_with_alpha"
-
-    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="JoinImageWithAlpha",
+            display_name="Join Image with Alpha",
+            category="mask/compositing",
+            inputs=[
+                io.Image.Input("image"),
+                io.Mask.Input("alpha"),
+            ],
+            outputs=[io.Image.Output()],
+        )
+
+    @classmethod
+    def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
         batch_size = min(len(image), len(alpha))
         out_images = []
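Because the node logic is now a classmethod, a converted node can be exercised directly with tensors for a quick sanity check, outside the graph runtime. A minimal sketch, assuming a ComfyUI environment where comfy_api.latest is importable and the file's usual comfy_extras location:

import torch
from comfy_extras.nodes_compositing import SplitImageWithAlpha  # assumed module path

rgba = torch.rand(2, 64, 64, 4)             # IMAGE batch, assumed [B, H, W, C] layout with values in 0..1
result = SplitImageWithAlpha.execute(rgba)  # classmethod call, no instance needed
# `result` is an io.NodeOutput wrapping the stacked RGB images and the inverted
# alpha channel (1.0 - alpha), matching the execute body shown above.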
@@ -196,19 +206,18 @@ class JoinImageWithAlpha:
         for i in range(batch_size):
             out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))

-        result = (torch.stack(out_images),)
-        return result
+        return io.NodeOutput(torch.stack(out_images))

-NODE_CLASS_MAPPINGS = {
-    "PorterDuffImageComposite": PorterDuffImageComposite,
-    "SplitImageWithAlpha": SplitImageWithAlpha,
-    "JoinImageWithAlpha": JoinImageWithAlpha,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "PorterDuffImageComposite": "Porter-Duff Image Composite",
-    "SplitImageWithAlpha": "Split Image with Alpha",
-    "JoinImageWithAlpha": "Join Image with Alpha",
-}
+class CompositingExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            PorterDuffImageComposite,
+            SplitImageWithAlpha,
+            JoinImageWithAlpha,
+        ]
+
+
+async def comfy_entrypoint() -> CompositingExtension:
+    return CompositingExtension()
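
For reference, here is the same conversion recipe distilled into a standalone skeleton, using only the APIs exercised in this commit (io.ComfyNode, io.Schema, io.NodeOutput, ComfyExtension, comfy_entrypoint). The node below is purely illustrative and not part of this change:

import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class InvertMaskExample(io.ComfyNode):  # hypothetical node, for illustration only
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="InvertMaskExample",
            display_name="Invert Mask (Example)",
            category="mask/compositing",
            inputs=[io.Mask.Input("mask")],
            outputs=[io.Mask.Output()],
        )

    @classmethod
    def execute(cls, mask: torch.Tensor) -> io.NodeOutput:
        return io.NodeOutput(1.0 - mask)  # stateless classmethod; no instance attributes


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [InvertMaskExample]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()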