mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 13:24:44 +08:00
880 lines
42 KiB
Python
880 lines
42 KiB
Python
from comfy.ldm.modules import attention as comfy_attention
|
|
import logging
|
|
import comfy.model_patcher
|
|
import comfy.utils
|
|
import comfy.sd
|
|
import torch
|
|
import folder_paths
|
|
import comfy.model_management as mm
|
|
from comfy.cli_args import args
|
|
|
|
orig_attention = comfy_attention.optimized_attention
|
|
original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
|
|
original_load_lora_for_models = comfy.sd.load_lora_for_models
|
|
|
|
class BaseLoaderKJ:
|
|
original_linear = None
|
|
cublas_patched = False
|
|
|
|
def _patch_modules(self, patch_cublaslinear, sage_attention):
|
|
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
|
|
|
|
if sage_attention != "disabled":
|
|
print("Patching comfy attention to use sageattn")
|
|
from sageattention import sageattn
|
|
def set_sage_func(sage_attention):
|
|
if sage_attention == "auto":
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
|
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
|
|
return func
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
|
|
from sageattention import sageattn_qk_int8_pv_fp16_cuda
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
|
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
|
|
return func
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
|
|
from sageattention import sageattn_qk_int8_pv_fp16_triton
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
|
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
|
|
return func
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
|
|
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
|
|
return func
|
|
|
|
sage_func = set_sage_func(sage_attention)
|
|
|
|
@torch.compiler.disable()
|
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
|
if skip_reshape:
|
|
b, _, _, dim_head = q.shape
|
|
tensor_layout="HND"
|
|
else:
|
|
b, _, dim_head = q.shape
|
|
dim_head //= heads
|
|
q, k, v = map(
|
|
lambda t: t.view(b, -1, heads, dim_head),
|
|
(q, k, v),
|
|
)
|
|
tensor_layout="NHD"
|
|
if mask is not None:
|
|
# add a batch dimension if there isn't already one
|
|
if mask.ndim == 2:
|
|
mask = mask.unsqueeze(0)
|
|
# add a heads dimension if there isn't already one
|
|
if mask.ndim == 3:
|
|
mask = mask.unsqueeze(1)
|
|
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
|
if tensor_layout == "HND":
|
|
if not skip_output_reshape:
|
|
out = (
|
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
|
)
|
|
else:
|
|
if skip_output_reshape:
|
|
out = out.transpose(1, 2)
|
|
else:
|
|
out = out.reshape(b, -1, heads * dim_head)
|
|
return out
|
|
|
|
comfy_attention.optimized_attention = attention_sage
|
|
comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
|
|
comfy.ldm.flux.math.optimized_attention = attention_sage
|
|
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
|
|
comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
|
|
comfy.ldm.wan.model.optimized_attention = attention_sage
|
|
|
|
else:
|
|
comfy_attention.optimized_attention = orig_attention
|
|
comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
|
|
comfy.ldm.flux.math.optimized_attention = orig_attention
|
|
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
|
|
comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
|
|
comfy.ldm.wan.model.optimized_attention = orig_attention
|
|
|
|
if patch_cublaslinear:
|
|
if not BaseLoaderKJ.cublas_patched:
|
|
BaseLoaderKJ.original_linear = disable_weight_init.Linear
|
|
try:
|
|
from cublas_ops import CublasLinear
|
|
except ImportError:
|
|
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
|
|
|
|
class PatchedLinear(CublasLinear, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
pass
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
weight, bias = cast_bias_weight(self, input)
|
|
return torch.nn.functional.linear(input, weight, bias)
|
|
|
|
def forward(self, *args, **kwargs):
|
|
if self.comfy_cast_weights:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
disable_weight_init.Linear = PatchedLinear
|
|
BaseLoaderKJ.cublas_patched = True
|
|
else:
|
|
if BaseLoaderKJ.cublas_patched:
|
|
disable_weight_init.Linear = BaseLoaderKJ.original_linear
|
|
BaseLoaderKJ.cublas_patched = False
|
|
|
|
class PathchSageAttentionKJ(BaseLoaderKJ):
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"model": ("MODEL",),
|
|
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
|
|
}}
|
|
|
|
RETURN_TYPES = ("MODEL", )
|
|
FUNCTION = "patch"
|
|
DESCRIPTION = "Experimental node for patching attention mode. This doesn't use the model patching system and thus can't be disabled without running the node again with 'disabled' option."
|
|
EXPERIMENTAL = True
|
|
CATEGORY = "KJNodes/experimental"
|
|
|
|
def patch(self, model, sage_attention):
|
|
self._patch_modules(False, sage_attention)
|
|
return model,
|
|
|
|
class CheckpointLoaderKJ(BaseLoaderKJ):
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
|
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
|
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
|
}}
|
|
|
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
|
FUNCTION = "patch"
|
|
OUTPUT_NODE = True
|
|
DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
|
|
EXPERIMENTAL = True
|
|
CATEGORY = "KJNodes/experimental"
|
|
|
|
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
|
|
self._patch_modules(patch_cublaslinear, sage_attention)
|
|
from nodes import CheckpointLoaderSimple
|
|
model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
|
|
return model, clip, vae
|
|
|
|
class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
|
|
"compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "fp16", "tooltip": "The compute dtype to use for the model."}),
|
|
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
|
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
|
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
|
}}
|
|
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch_and_load"
|
|
OUTPUT_NODE = True
|
|
DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
|
|
EXPERIMENTAL = True
|
|
CATEGORY = "KJNodes/experimental"
|
|
|
|
def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
|
|
DTYPE_MAP = {
|
|
"fp8_e4m3fn": torch.float8_e4m3fn,
|
|
"fp8_e5m2": torch.float8_e5m2,
|
|
"fp16": torch.float16,
|
|
"bf16": torch.bfloat16,
|
|
"fp32": torch.float32
|
|
}
|
|
model_options = {}
|
|
if dtype := DTYPE_MAP.get(weight_dtype):
|
|
model_options["dtype"] = dtype
|
|
print(f"Setting {model_name} weight dtype to {dtype}")
|
|
|
|
if weight_dtype == "fp8_e4m3fn_fast":
|
|
model_options["dtype"] = torch.float8_e4m3fn
|
|
model_options["fp8_optimizations"] = True
|
|
|
|
if enable_fp16_accumulation:
|
|
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
|
else:
|
|
raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
|
|
else:
|
|
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = False
|
|
|
|
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
|
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
|
if dtype := DTYPE_MAP.get(compute_dtype):
|
|
model.set_model_compute_dtype(dtype)
|
|
model.force_cast_weights = False
|
|
print(f"Setting {model_name} compute dtype to {dtype}")
|
|
self._patch_modules(patch_cublaslinear, sage_attention)
|
|
|
|
return (model,)
|
|
|
|
def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
|
|
|
if lowvram_model_memory == 0:
|
|
full_load = True
|
|
else:
|
|
full_load = False
|
|
|
|
device_to = mm.get_torch_device()
|
|
load_weights = True
|
|
if load_weights:
|
|
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
|
for k in self.object_patches:
|
|
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
|
if k not in self.object_patches_backup:
|
|
self.object_patches_backup[k] = old
|
|
|
|
return self.model
|
|
|
|
def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|
|
|
patch_keys = list(model.object_patches_backup.keys())
|
|
for k in patch_keys:
|
|
#print("backing up object patch: ", k)
|
|
comfy.utils.set_attr(model.model, k, model.object_patches_backup[k])
|
|
|
|
key_map = {}
|
|
if model is not None:
|
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
|
if clip is not None:
|
|
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
|
|
|
loaded = comfy.lora.load_lora(lora, key_map)
|
|
#print(temp_object_patches_backup)
|
|
|
|
if model is not None:
|
|
new_modelpatcher = model.clone()
|
|
k = new_modelpatcher.add_patches(loaded, strength_model)
|
|
else:
|
|
k = ()
|
|
new_modelpatcher = None
|
|
|
|
if clip is not None:
|
|
new_clip = clip.clone()
|
|
k1 = new_clip.add_patches(loaded, strength_clip)
|
|
else:
|
|
k1 = ()
|
|
new_clip = None
|
|
k = set(k)
|
|
k1 = set(k1)
|
|
for x in loaded:
|
|
if (x not in k) and (x not in k1):
|
|
print("NOT LOADED {}".format(x))
|
|
|
|
if patch_keys:
|
|
if hasattr(model.model, "compile_settings"):
|
|
compile_settings = getattr(model.model, "compile_settings")
|
|
print("compile_settings: ", compile_settings)
|
|
for k in patch_keys:
|
|
if "diffusion_model." in k:
|
|
# Remove the prefix to get the attribute path
|
|
key = k.replace('diffusion_model.', '')
|
|
attributes = key.split('.')
|
|
# Start with the diffusion_model object
|
|
block = model.get_model_object("diffusion_model")
|
|
# Navigate through the attributes to get to the block
|
|
for attr in attributes:
|
|
if attr.isdigit():
|
|
block = block[int(attr)]
|
|
else:
|
|
block = getattr(block, attr)
|
|
# Compile the block
|
|
compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
|
|
# Add the compiled block back as an object patch
|
|
model.add_object_patch(k, compiled_block)
|
|
return (new_modelpatcher, new_clip)
|
|
|
|
class PatchModelPatcherOrder:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"model": ("MODEL",),
|
|
"patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
|
|
}}
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch"
|
|
CATEGORY = "KJNodes/experimental"
|
|
DESCRIPTION = "Patch the comfy patch_model function patching order, useful for torch.compile (used as object_patch) as it should come last if you want to use LoRAs with compile"
|
|
EXPERIMENTAL = True
|
|
|
|
def patch(self, model, patch_order):
|
|
comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
|
|
if patch_order == "weight_patch_first":
|
|
comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
|
|
comfy.sd.load_lora_for_models = patched_load_lora_for_models
|
|
else:
|
|
comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
|
|
comfy.sd.load_lora_for_models = original_load_lora_for_models
|
|
|
|
return model,
|
|
|
|
class TorchCompileModelFluxAdvanced:
|
|
def __init__(self):
|
|
self._compiled = False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"model": ("MODEL",),
|
|
"backend": (["inductor", "cudagraphs"],),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
|
|
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
|
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
|
}}
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def parse_blocks(self, blocks_str):
|
|
blocks = []
|
|
for part in blocks_str.split(','):
|
|
part = part.strip()
|
|
if '-' in part:
|
|
start, end = map(int, part.split('-'))
|
|
blocks.extend(range(start, end + 1))
|
|
else:
|
|
blocks.append(int(part))
|
|
return blocks
|
|
|
|
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
|
|
single_block_list = self.parse_blocks(single_blocks)
|
|
double_block_list = self.parse_blocks(double_blocks)
|
|
m = model.clone()
|
|
diffusion_model = m.get_model_object("diffusion_model")
|
|
|
|
if not self._compiled:
|
|
try:
|
|
for i, block in enumerate(diffusion_model.double_blocks):
|
|
if i in double_block_list:
|
|
#print("Compiling double_block", i)
|
|
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
|
|
for i, block in enumerate(diffusion_model.single_blocks):
|
|
if i in single_block_list:
|
|
#print("Compiling single block", i)
|
|
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
|
|
self._compiled = True
|
|
compile_settings = {
|
|
"backend": backend,
|
|
"mode": mode,
|
|
"fullgraph": fullgraph,
|
|
"dynamic": dynamic,
|
|
}
|
|
setattr(m.model, "compile_settings", compile_settings)
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
|
|
return (m, )
|
|
# rest of the layers that are not patched
|
|
# diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
# diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
# diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
# diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
|
|
class TorchCompileModelHyVideo:
|
|
def __init__(self):
|
|
self._compiled = False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"model": ("MODEL",),
|
|
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
|
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
|
|
"compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
|
|
"compile_double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
|
|
"compile_txt_in": ("BOOLEAN", {"default": False, "tooltip": "Compile txt_in layers"}),
|
|
"compile_vector_in": ("BOOLEAN", {"default": False, "tooltip": "Compile vector_in layers"}),
|
|
"compile_final_layer": ("BOOLEAN", {"default": False, "tooltip": "Compile final layer"}),
|
|
|
|
},
|
|
}
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks, compile_txt_in, compile_vector_in, compile_final_layer):
|
|
m = model.clone()
|
|
diffusion_model = m.get_model_object("diffusion_model")
|
|
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
|
|
if not self._compiled:
|
|
try:
|
|
if compile_single_blocks:
|
|
for i, block in enumerate(diffusion_model.single_blocks):
|
|
compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
|
m.add_object_patch(f"diffusion_model.single_blocks.{i}", compiled_block)
|
|
if compile_double_blocks:
|
|
for i, block in enumerate(diffusion_model.double_blocks):
|
|
compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
|
m.add_object_patch(f"diffusion_model.double_blocks.{i}", compiled_block)
|
|
if compile_txt_in:
|
|
compiled_block = torch.compile(diffusion_model.txt_in, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
|
m.add_object_patch("diffusion_model.txt_in", compiled_block)
|
|
if compile_vector_in:
|
|
compiled_block = torch.compile(diffusion_model.vector_in, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
|
m.add_object_patch("diffusion_model.vector_in", compiled_block)
|
|
if compile_final_layer:
|
|
compiled_block = torch.compile(diffusion_model.final_layer, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
|
m.add_object_patch("diffusion_model.final_layer", compiled_block)
|
|
self._compiled = True
|
|
compile_settings = {
|
|
"backend": backend,
|
|
"mode": mode,
|
|
"fullgraph": fullgraph,
|
|
"dynamic": dynamic,
|
|
}
|
|
setattr(m.model, "compile_settings", compile_settings)
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
return (m, )
|
|
|
|
class TorchCompileModelWanVideo:
|
|
def __init__(self):
|
|
self._compiled = False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"model": ("MODEL",),
|
|
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
|
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
|
|
"compile_transformer_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile all transformer blocks"}),
|
|
},
|
|
}
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks):
|
|
m = model.clone()
|
|
diffusion_model = m.get_model_object("diffusion_model")
|
|
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
|
|
if not self._compiled:
|
|
try:
|
|
if compile_transformer_blocks:
|
|
for i, block in enumerate(diffusion_model.blocks):
|
|
compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
|
m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
|
|
self._compiled = True
|
|
compile_settings = {
|
|
"backend": backend,
|
|
"mode": mode,
|
|
"fullgraph": fullgraph,
|
|
"dynamic": dynamic,
|
|
}
|
|
setattr(m.model, "compile_settings", compile_settings)
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
return (m, )
|
|
|
|
class TorchCompileVAE:
|
|
def __init__(self):
|
|
self._compiled_encoder = False
|
|
self._compiled_decoder = False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"vae": ("VAE",),
|
|
"backend": (["inductor", "cudagraphs"],),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
"compile_encoder": ("BOOLEAN", {"default": True, "tooltip": "Compile encoder"}),
|
|
"compile_decoder": ("BOOLEAN", {"default": True, "tooltip": "Compile decoder"}),
|
|
}}
|
|
RETURN_TYPES = ("VAE",)
|
|
FUNCTION = "compile"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
|
|
if compile_encoder:
|
|
if not self._compiled_encoder:
|
|
encoder_name = "encoder"
|
|
if hasattr(vae.first_stage_model, "taesd_encoder"):
|
|
encoder_name = "taesd_encoder"
|
|
|
|
try:
|
|
setattr(
|
|
vae.first_stage_model,
|
|
encoder_name,
|
|
torch.compile(
|
|
getattr(vae.first_stage_model, encoder_name),
|
|
mode=mode,
|
|
fullgraph=fullgraph,
|
|
backend=backend,
|
|
),
|
|
)
|
|
self._compiled_encoder = True
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
if compile_decoder:
|
|
if not self._compiled_decoder:
|
|
decoder_name = "decoder"
|
|
if hasattr(vae.first_stage_model, "taesd_decoder"):
|
|
decoder_name = "taesd_decoder"
|
|
|
|
try:
|
|
setattr(
|
|
vae.first_stage_model,
|
|
decoder_name,
|
|
torch.compile(
|
|
getattr(vae.first_stage_model, decoder_name),
|
|
mode=mode,
|
|
fullgraph=fullgraph,
|
|
backend=backend,
|
|
),
|
|
)
|
|
self._compiled_decoder = True
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
return (vae, )
|
|
|
|
class TorchCompileControlNet:
|
|
def __init__(self):
|
|
self._compiled= False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"controlnet": ("CONTROL_NET",),
|
|
"backend": (["inductor", "cudagraphs"],),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
}}
|
|
RETURN_TYPES = ("CONTROL_NET",)
|
|
FUNCTION = "compile"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def compile(self, controlnet, backend, mode, fullgraph):
|
|
if not self._compiled:
|
|
try:
|
|
# for i, block in enumerate(controlnet.control_model.double_blocks):
|
|
# print("Compiling controlnet double_block", i)
|
|
# controlnet.control_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
controlnet.control_model = torch.compile(controlnet.control_model, mode=mode, fullgraph=fullgraph, backend=backend)
|
|
self._compiled = True
|
|
except:
|
|
self._compiled = False
|
|
raise RuntimeError("Failed to compile model")
|
|
|
|
return (controlnet, )
|
|
|
|
class TorchCompileLTXModel:
|
|
def __init__(self):
|
|
self._compiled = False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"model": ("MODEL",),
|
|
"backend": (["inductor", "cudagraphs"],),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
|
}}
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def patch(self, model, backend, mode, fullgraph, dynamic):
|
|
m = model.clone()
|
|
diffusion_model = m.get_model_object("diffusion_model")
|
|
|
|
if not self._compiled:
|
|
try:
|
|
for i, block in enumerate(diffusion_model.transformer_blocks):
|
|
compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
|
|
m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
|
|
self._compiled = True
|
|
compile_settings = {
|
|
"backend": backend,
|
|
"mode": mode,
|
|
"fullgraph": fullgraph,
|
|
"dynamic": dynamic,
|
|
}
|
|
setattr(m.model, "compile_settings", compile_settings)
|
|
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
|
|
return (m, )
|
|
|
|
class TorchCompileCosmosModel:
|
|
def __init__(self):
|
|
self._compiled = False
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required": {
|
|
"model": ("MODEL",),
|
|
"backend": (["inductor", "cudagraphs"],),
|
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
|
"dynamo_cache_size_limit": ("INT", {"default": 64, "tooltip": "Set the dynamo cache size limit"}),
|
|
}}
|
|
RETURN_TYPES = ("MODEL",)
|
|
FUNCTION = "patch"
|
|
|
|
CATEGORY = "KJNodes/torchcompile"
|
|
EXPERIMENTAL = True
|
|
|
|
def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit):
|
|
|
|
m = model.clone()
|
|
diffusion_model = m.get_model_object("diffusion_model")
|
|
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
|
|
|
|
if not self._compiled:
|
|
try:
|
|
for name, block in diffusion_model.blocks.items():
|
|
#print(f"Compiling block {name}")
|
|
compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
|
|
m.add_object_patch(f"diffusion_model.blocks.{name}", compiled_block)
|
|
#diffusion_model.blocks[name] = compiled_block
|
|
|
|
self._compiled = True
|
|
compile_settings = {
|
|
"backend": backend,
|
|
"mode": mode,
|
|
"fullgraph": fullgraph,
|
|
"dynamic": dynamic,
|
|
}
|
|
setattr(m.model, "compile_settings", compile_settings)
|
|
|
|
except:
|
|
raise RuntimeError("Failed to compile model")
|
|
|
|
return (m, )
|
|
|
|
|
|
#teacache
|
|
|
|
from comfy.ldm.wan.model import sinusoidal_embedding_1d
|
|
from einops import repeat
|
|
from unittest.mock import patch
|
|
from contextlib import nullcontext
|
|
|
|
def relative_l1_distance(last_tensor, current_tensor):
|
|
l1_distance = torch.abs(last_tensor - current_tensor).mean()
|
|
norm = torch.abs(last_tensor).mean()
|
|
relative_l1_distance = l1_distance / norm
|
|
return relative_l1_distance.to(torch.float32)
|
|
|
|
#for now as there doesn't seem to be a way to pass transformer_options to the forward_orig currently
|
|
def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
|
|
bs, c, t, h, w = x.shape
|
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
|
patch_size = self.patch_size
|
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
|
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
|
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
|
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
|
|
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
|
|
|
|
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
|
|
# embeddings
|
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
|
grid_sizes = x.shape[2:]
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
# time embeddings
|
|
e = self.time_embedding(
|
|
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
|
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
|
|
|
# context
|
|
context = self.text_embedding(context)
|
|
if clip_fea is not None and self.img_emb is not None:
|
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
|
context = torch.concat([context_clip, context], dim=1)
|
|
|
|
#teacache for cond and uncond separately
|
|
rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
|
|
cache_device = kwargs["transformer_options"]["teacache_device"]
|
|
is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
|
|
|
|
should_calc = True
|
|
suffix = "cond" if is_cond else "uncond"
|
|
|
|
# Init cache dict if not exists
|
|
if not hasattr(self, 'teacache_state'):
|
|
self.teacache_state = {
|
|
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
|
'teacache_skipped_steps': 0, 'previous_residual': None},
|
|
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
|
'teacache_skipped_steps': 0, 'previous_residual': None}
|
|
}
|
|
logging.info("TeaCache: Initialized")
|
|
|
|
cache = self.teacache_state[suffix]
|
|
|
|
if cache['prev_input'] is not None:
|
|
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
|
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
|
try:
|
|
if curr_acc_dist < rel_l1_thresh:
|
|
should_calc = False
|
|
cache['accumulated_rel_l1_distance'] = curr_acc_dist
|
|
else:
|
|
should_calc = True
|
|
cache['accumulated_rel_l1_distance'] = 0
|
|
except:
|
|
should_calc = True
|
|
cache['accumulated_rel_l1_distance'] = 0
|
|
|
|
cache['prev_input'] = e0.clone().detach()
|
|
|
|
if not should_calc:
|
|
x += cache['previous_residual'].to(x.device)
|
|
cache['teacache_skipped_steps'] += 1
|
|
#print(f"TeaCache: Skipping {suffix} step")
|
|
|
|
if should_calc:
|
|
original_x = x.clone().detach()
|
|
# arguments
|
|
block_wargs = dict(
|
|
e=e0,
|
|
freqs=freqs,
|
|
context=context)
|
|
|
|
for block in self.blocks:
|
|
x = block(x, **block_wargs)
|
|
|
|
cache['previous_residual'] = (x - original_x).to(cache_device)
|
|
|
|
# head
|
|
x = self.head(x, e)
|
|
|
|
# unpatchify
|
|
x = self.unpatchify(x, grid_sizes)
|
|
return x
|
|
|
|
class WanVideoTeaCacheKJ:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"model": ("MODEL",),
|
|
"rel_l1_thresh": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold for to determine when to apply the cache, compromise between speed and accuracy"}),
|
|
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
|
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
|
|
"cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
|
|
}
|
|
}
|
|
|
|
RETURN_TYPES = ("MODEL",)
|
|
RETURN_NAMES = ("model",)
|
|
FUNCTION = "patch_teacache"
|
|
CATEGORY = "KJNodes/teacache"
|
|
DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future"
|
|
EXPERIMENTAL = True
|
|
|
|
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
|
|
if rel_l1_thresh == 0:
|
|
return (model,)
|
|
|
|
teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
|
|
|
|
model_clone = model.clone()
|
|
if 'transformer_options' not in model_clone.model_options:
|
|
model_clone.model_options['transformer_options'] = {}
|
|
model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
|
|
model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
|
|
diffusion_model = model_clone.get_model_object("diffusion_model")
|
|
|
|
def outer_wrapper(start_percent, end_percent):
|
|
def unet_wrapper_function(model_function, kwargs):
|
|
input = kwargs["input"]
|
|
timestep = kwargs["timestep"]
|
|
c = kwargs["c"]
|
|
sigmas = c["transformer_options"]["sample_sigmas"]
|
|
cond_or_uncond = kwargs["cond_or_uncond"]
|
|
last_step = (len(sigmas) - 1)
|
|
|
|
matched_step_index = (sigmas == timestep[0] ).nonzero()
|
|
if len(matched_step_index) > 0:
|
|
current_step_index = matched_step_index.item()
|
|
else:
|
|
for i in range(len(sigmas) - 1):
|
|
# walk from beginning of steps until crossing the timestep
|
|
if (sigmas[i] - timestep) * (sigmas[i + 1] - timestep) <= 0:
|
|
current_step_index = i
|
|
break
|
|
else:
|
|
current_step_index = 0
|
|
|
|
if current_step_index == 0:
|
|
if hasattr(diffusion_model, "teacache_state"):
|
|
delattr(diffusion_model, "teacache_state")
|
|
logging.info("Resetting TeaCache state")
|
|
|
|
current_percent = current_step_index / (len(sigmas) - 1)
|
|
if start_percent <= current_percent <= end_percent:
|
|
c["transformer_options"]["teacache_enabled"] = True
|
|
|
|
context = patch.multiple(
|
|
diffusion_model,
|
|
forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
|
|
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
|
)
|
|
else:
|
|
context = nullcontext()
|
|
with context:
|
|
out = model_function(input, timestep, **c)
|
|
if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
|
|
if cond_or_uncond[0] == 0:
|
|
skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"]
|
|
skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]
|
|
|
|
logging.info("-----------------------------------")
|
|
logging.info(f"TeaCache skipped:")
|
|
logging.info(f"{skipped_steps_cond} cond steps")
|
|
logging.info(f"{skipped_steps_uncond} uncond step")
|
|
logging.info(f"out of {last_step} steps")
|
|
logging.info("-----------------------------------")
|
|
return out
|
|
return unet_wrapper_function
|
|
|
|
model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))
|
|
|
|
return (model_clone,)
|