mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-15 00:44:25 +08:00
convert nodes_hypertile.py to V3 schema (#10061)
This commit is contained in:
parent
1cf86f5ae5
commit
2dadb34860
@ -1,9 +1,11 @@
|
|||||||
#Taken from: https://github.com/tfernd/HyperTile/
|
#Taken from: https://github.com/tfernd/HyperTile/
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing_extensions import override
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
# Use torch rng for consistency across generations
|
# Use torch rng for consistency across generations
|
||||||
from torch import randint
|
from torch import randint
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
min_value = min(min_value, value)
|
min_value = min(min_value, value)
|
||||||
@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
|||||||
|
|
||||||
return ns[idx]
|
return ns[idx]
|
||||||
|
|
||||||
class HyperTile:
|
class HyperTile(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
|
node_id="HyperTile",
|
||||||
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
|
category="model_patches/unet",
|
||||||
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
|
inputs=[
|
||||||
"scale_depth": ("BOOLEAN", {"default": False}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Int.Input("tile_size", default=256, min=1, max=2048),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Int.Input("swap_size", default=2, min=1, max=128),
|
||||||
FUNCTION = "patch"
|
io.Int.Input("max_depth", default=0, min=0, max=10),
|
||||||
|
io.Boolean.Input("scale_depth", default=False),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "model_patches/unet"
|
@classmethod
|
||||||
|
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
|
||||||
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
|
|
||||||
latent_tile_size = max(32, tile_size) // 8
|
latent_tile_size = max(32, tile_size) // 8
|
||||||
self.temp = None
|
temp = None
|
||||||
|
|
||||||
def hypertile_in(q, k, v, extra_options):
|
def hypertile_in(q, k, v, extra_options):
|
||||||
|
nonlocal temp
|
||||||
model_chans = q.shape[-2]
|
model_chans = q.shape[-2]
|
||||||
orig_shape = extra_options['original_shape']
|
orig_shape = extra_options['original_shape']
|
||||||
apply_to = []
|
apply_to = []
|
||||||
@ -58,14 +66,15 @@ class HyperTile:
|
|||||||
|
|
||||||
if nh * nw > 1:
|
if nh * nw > 1:
|
||||||
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
self.temp = (nh, nw, h, w)
|
temp = (nh, nw, h, w)
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
return q, k, v
|
return q, k, v
|
||||||
def hypertile_out(out, extra_options):
|
def hypertile_out(out, extra_options):
|
||||||
if self.temp is not None:
|
nonlocal temp
|
||||||
nh, nw, h, w = self.temp
|
if temp is not None:
|
||||||
self.temp = None
|
nh, nw, h, w = temp
|
||||||
|
temp = None
|
||||||
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||||
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||||
return out
|
return out
|
||||||
@ -76,6 +85,14 @@ class HyperTile:
|
|||||||
m.set_model_attn1_output_patch(hypertile_out)
|
m.set_model_attn1_output_patch(hypertile_out)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"HyperTile": HyperTile,
|
class HyperTileExtension(ComfyExtension):
|
||||||
}
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
HyperTile,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> HyperTileExtension:
|
||||||
|
return HyperTileExtension()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user