convert nodes_hypertile.py to V3 schema (#10061)

This commit is contained in:
Alexander Piskun 2025-09-28 05:16:22 +03:00 committed by GitHub
parent 1cf86f5ae5
commit 2dadb34860
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,11 @@
#Taken from: https://github.com/tfernd/HyperTile/ #Taken from: https://github.com/tfernd/HyperTile/
import math import math
from typing_extensions import override
from einops import rearrange from einops import rearrange
# Use torch rng for consistency across generations # Use torch rng for consistency across generations
from torch import randint from torch import randint
from comfy_api.latest import ComfyExtension, io
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value) min_value = min(min_value, value)
@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
return ns[idx] return ns[idx]
class HyperTile: class HyperTile(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), node_id="HyperTile",
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), category="model_patches/unet",
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), inputs=[
"scale_depth": ("BOOLEAN", {"default": False}), io.Model.Input("model"),
}} io.Int.Input("tile_size", default=256, min=1, max=2048),
RETURN_TYPES = ("MODEL",) io.Int.Input("swap_size", default=2, min=1, max=128),
FUNCTION = "patch" io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
latent_tile_size = max(32, tile_size) // 8 latent_tile_size = max(32, tile_size) // 8
self.temp = None temp = None
def hypertile_in(q, k, v, extra_options): def hypertile_in(q, k, v, extra_options):
nonlocal temp
model_chans = q.shape[-2] model_chans = q.shape[-2]
orig_shape = extra_options['original_shape'] orig_shape = extra_options['original_shape']
apply_to = [] apply_to = []
@ -58,14 +66,15 @@ class HyperTile:
if nh * nw > 1: if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
self.temp = (nh, nw, h, w) temp = (nh, nw, h, w)
return q, k, v return q, k, v
return q, k, v return q, k, v
def hypertile_out(out, extra_options): def hypertile_out(out, extra_options):
if self.temp is not None: nonlocal temp
nh, nw, h, w = self.temp if temp is not None:
self.temp = None nh, nw, h, w = temp
temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out return out
@ -76,6 +85,14 @@ class HyperTile:
m.set_model_attn1_output_patch(hypertile_out) m.set_model_attn1_output_patch(hypertile_out)
return (m, ) return (m, )
NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile, class HyperTileExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
HyperTile,
]
async def comfy_entrypoint() -> HyperTileExtension:
return HyperTileExtension()