import os
from comfy.ldm.modules import attention as comfy_attention
import logging
import torch
import importlib
import math
import datetime
import folder_paths
import comfy.model_management as mm
from comfy.cli_args import args
from comfy.ldm.modules.attention import wrap_attn, optimized_attention
import comfy.model_patcher
import comfy.utils
import comfy.sd

try:
    from comfy_api.latest import io
    v3_available = True
except ImportError:
    v3_available = False
    logging.warning("ComfyUI v3 node API not available, please update ComfyUI to access the latest v3 nodes.")

sageattn_modes = [
    "disabled",
    "auto",
    "sageattn_qk_int8_pv_fp16_cuda",
    "sageattn_qk_int8_pv_fp16_triton",
    "sageattn_qk_int8_pv_fp8_cuda",
    "sageattn_qk_int8_pv_fp8_cuda++",
    "sageattn3",
    "sageattn3_per_block_mean",
]

_initialized = False
_original_functions = {}

if not _initialized:
    _original_functions["orig_attention"] = comfy_attention.optimized_attention
    _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
    _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
    try:
        _original_functions["original_qwen_forward"] = comfy.ldm.qwen_image.model.Attention.forward
    except Exception:
        pass
    _initialized = True

def get_sage_func(sage_attention, allow_compile=False):
    # Returns a ComfyUI-compatible attention function that dispatches to the selected sageattention kernel.
    logging.info(f"Using sage attention mode: {sage_attention}")
    from sageattention import sageattn

    if sage_attention == "auto":
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
        from sageattention import sageattn_qk_int8_pv_fp16_cuda
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
        from sageattention import sageattn_qk_int8_pv_fp16_triton
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
        from sageattention import sageattn_qk_int8_pv_fp8_cuda
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
        from sageattention import sageattn_qk_int8_pv_fp8_cuda
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
    elif "sageattn3" in sage_attention:
        from sageattn3 import sageattn3_blackwell
        if sage_attention == "sageattn3_per_block_mean":
            def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
                return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
        else:
            def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
                return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)

    if not allow_compile:
        sage_func = torch.compiler.disable()(sage_func)

    @wrap_attn
    def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
        in_dtype = v.dtype
        # sageattention kernels do not accept fp32 inputs, cast to fp16
        if q.dtype == torch.float32 or k.dtype == torch.float32 or v.dtype == torch.float32:
            q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
        if skip_reshape:
            b, _, _, dim_head = q.shape
            tensor_layout = "HND"
        else:
            b, _, dim_head = q.shape
            dim_head //= heads
            q, k, v = map(
                lambda t: t.view(b, -1, heads, dim_head),
                (q, k, v),
            )
            tensor_layout = "NHD"
        if mask is not None:
            # add a batch dimension if there isn't already one
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            # add a heads dimension if there isn't already one
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)

        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout).to(in_dtype)

        if tensor_layout == "HND":
            if not skip_output_reshape:
                out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        else:
            if skip_output_reshape:
                out = out.transpose(1, 2)
            else:
                out = out.reshape(b, -1, heads * dim_head)
        return out

    return attention_sage

class BaseLoaderKJ:
    original_linear = None
    cublas_patched = False

    def _patch_modules(self, patch_cublaslinear, sage_attention):
        from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight

        if patch_cublaslinear:
            if not BaseLoaderKJ.cublas_patched:
                BaseLoaderKJ.original_linear = disable_weight_init.Linear
                try:
                    from cublas_ops import CublasLinear
                except ImportError:
                    raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")

                class PatchedLinear(CublasLinear, CastWeightBiasOp):
                    def reset_parameters(self):
                        pass

                    def forward_comfy_cast_weights(self, input):
                        weight, bias = cast_bias_weight(self, input)
                        return torch.nn.functional.linear(input, weight, bias)

                    def forward(self, *args, **kwargs):
                        if self.comfy_cast_weights:
                            return self.forward_comfy_cast_weights(*args, **kwargs)
                        else:
                            return super().forward(*args, **kwargs)

                disable_weight_init.Linear = PatchedLinear
                BaseLoaderKJ.cublas_patched = True
        else:
            if BaseLoaderKJ.cublas_patched:
                disable_weight_init.Linear = BaseLoaderKJ.original_linear
                BaseLoaderKJ.cublas_patched = False

from comfy.patcher_extension import CallbacksMP

class PathchSageAttentionKJ():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "sage_attention": (sageattn_modes, {"default": "disabled", "tooltip": "Globally patches comfy attention to use sageattn. Once patched, to revert back to normal you need to run this node again with the 'disabled' option."}),
            },
            "optional": {
                "allow_compile": ("BOOLEAN", {"default": False, "tooltip": "Allow the use of torch.compile for the sage attention function, requires sageattn 2.2.0 or higher."}),
            }
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching attention mode. This doesn't use the model patching system and thus can't be disabled without running the node again with the 'disabled' option."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, sage_attention, allow_compile=False):
        if sage_attention == "disabled":
            return model,
        model_clone = model.clone()

        new_attention = get_sage_func(sage_attention, allow_compile=allow_compile)

        def attention_override_sage(func, *args, **kwargs):
            return new_attention.__wrapped__(*args, **kwargs)

        # attention override
        model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

        return model_clone,

class CheckpointLoaderKJ(BaseLoaderKJ):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
            "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
            "sage_attention": (sageattn_modes, {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.1 or higher."}),
        }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
        DTYPE_MAP = {
            "fp8_e4m3fn": torch.float8_e4m3fn,
            "fp8_e5m2": torch.float8_e5m2,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32
        }
        model_options = {}
        if dtype := DTYPE_MAP.get(weight_dtype):
            model_options["dtype"] = dtype
            logging.info(f"Setting {ckpt_name} weight dtype to {dtype}")
        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["dtype"] = torch.float8_e4m3fn
            model_options["fp8_optimizations"] = True

        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
        model, clip, vae = self.load_state_dict_guess_config(
            sd, output_vae=True, output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            metadata=metadata, model_options=model_options)

        if dtype := DTYPE_MAP.get(compute_dtype):
            model.set_model_compute_dtype(dtype)
            model.force_cast_weights = False
            logging.info(f"Setting {ckpt_name} compute dtype to {dtype}")

        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = True
            else:
                raise RuntimeError("Failed to set fp16 accumulation, requires pytorch version 2.7.1 or higher")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = False

        if sage_attention != "disabled":
            new_attention = get_sage_func(sage_attention)

            def attention_override_sage(func, *args, **kwargs):
                return new_attention.__wrapped__(*args, **kwargs)

            # attention override
            model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

        return model, clip, vae

    def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None,
                                     output_model=True, model_options={}, te_model_options={}, metadata=None):
        from comfy.sd import load_diffusion_model_state_dict, model_detection, VAE, CLIP
        clip = None
        vae = None
        model = None
        model_patcher = None

        diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
        parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
        weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
        load_device = mm.get_torch_device()

        model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)

        if model_config is None:
            logging.warning("Warning, this is not a checkpoint file, trying to load it as a diffusion model only.")
            diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
            if diffusion_model is None:
                return None
            return (diffusion_model, None, VAE(sd={}), None)  # The VAE object is there to throw an exception if it's actually used

        unet_weight_dtype = list(model_config.supported_inference_dtypes)
        if model_config.scaled_fp8 is not None:
            weight_dtype = None

        model_config.custom_operations = model_options.get("custom_operations", None)
        unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))

        if unet_dtype is None:
            unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)

        manual_cast_dtype = mm.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

        if output_model:
            inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
            model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
            model.load_model_weights(sd, diffusion_model_prefix)

        if output_vae:
            vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
            vae_sd = model_config.process_vae_state_dict(vae_sd)
            vae = VAE(sd=vae_sd, metadata=metadata)

        if output_clip:
            clip_target = model_config.clip_target(state_dict=sd)
            if clip_target is not None:
                clip_sd = model_config.process_clip_state_dict(sd)
                if len(clip_sd) > 0:
                    parameters = comfy.utils.calculate_parameters(clip_sd)
                    clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
                    m, u = clip.load_sd(clip_sd, full_model=True)
                    if len(m) > 0:
                        m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
                        if len(m_filter) > 0:
                            logging.warning("clip missing: {}".format(m))
                        else:
                            logging.debug("clip missing: {}".format(m))
                    if len(u) > 0:
                        logging.debug("clip unexpected: {}".format(u))
                else:
                    logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

        left_over = sd.keys()
        if len(left_over) > 0:
            logging.debug("left over keys: {}".format(left_over))

        if output_model:
            model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=mm.unet_offload_device())
            if inital_load_device != torch.device("cpu"):
                logging.info("loaded diffusion model directly to GPU")
                mm.load_models_gpu([model_patcher], force_full_load=True)

        return (model_patcher, clip, vae)

class DiffusionModelSelector():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("model_path",)
    FUNCTION = "get_path"
    DESCRIPTION = "Returns the path to the model as a string."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def get_path(self, model_name):
        model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
        return (model_path,)

class DiffusionModelLoaderKJ(BaseLoaderKJ):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
            "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
            "sage_attention": (sageattn_modes, {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.1 or higher."}),
            },
            "optional": {
                "extra_state_dict": ("STRING", {"forceInput": True, "tooltip": "The full path to an additional state dict to load, this will be merged with the main state dict. Useful for example to add the VACE module to a WanVideo model. You can use DiffusionModelSelector to easily get the path."}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch_and_load"
    DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
        DTYPE_MAP = {
            "fp8_e4m3fn": torch.float8_e4m3fn,
            "fp8_e5m2": torch.float8_e5m2,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32
        }
        model_options = {}
        if dtype := DTYPE_MAP.get(weight_dtype):
            model_options["dtype"] = dtype
            logging.info(f"Setting {model_name} weight dtype to {dtype}")
        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["dtype"] = torch.float8_e4m3fn
            model_options["fp8_optimizations"] = True

        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = True
            else:
                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.1 or higher")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = False

        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
        sd = comfy.utils.load_torch_file(unet_path)
        if extra_state_dict is not None:
            # If the model is a checkpoint, strip the non-diffusion-model entries before adding the extra state dict
            from comfy import model_detection
            diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
            if diffusion_model_prefix == "model.diffusion_model.":
                temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
                if len(temp_sd) > 0:
                    sd = temp_sd
            extra_sd = comfy.utils.load_torch_file(extra_state_dict)
            sd.update(extra_sd)
            del extra_sd
        model = comfy.sd.load_diffusion_model_state_dict(sd, model_options=model_options)

        if dtype := DTYPE_MAP.get(compute_dtype):
            model.set_model_compute_dtype(dtype)
            model.force_cast_weights = False
            logging.info(f"Setting {model_name} compute dtype to {dtype}")

        if sage_attention != "disabled":
            new_attention = get_sage_func(sage_attention)

            def attention_override_sage(func, *args, **kwargs):
return new_attention.__wrapped__(*args, **kwargs) # attention override model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage return (model,) class ModelPatchTorchSettings: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" DESCRIPTION = "Adds callbacks to model to set torch settings before and after running the model." EXPERIMENTAL = True CATEGORY = "KJNodes/experimental" def patch(self, model, enable_fp16_accumulation): model_clone = model.clone() def patch_enable_fp16_accum(model): logging.info("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = True") torch.backends.cuda.matmul.allow_fp16_accumulation = True def patch_disable_fp16_accum(model): logging.info("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = False") torch.backends.cuda.matmul.allow_fp16_accumulation = False if enable_fp16_accumulation: if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_enable_fp16_accum) model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_disable_fp16_accum) else: raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.1 or higher") else: if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_disable_fp16_accum) else: raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.1 or higher") return (model_clone,) def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): with self.use_ejected(): device_to = mm.get_torch_device() full_load_override = getattr(self.model, "full_load_override", "auto") if full_load_override in ["enabled", "disabled"]: full_load = full_load_override == "enabled" else: full_load = lowvram_model_memory == 0 self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) for k in self.object_patches: old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old self.inject_model() return self.model def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip): patch_keys = list(model.object_patches_backup.keys()) for k in patch_keys: #print("backing up object patch: ", k) comfy.utils.set_attr(model.model, k, model.object_patches_backup[k]) key_map = {} if model is not None: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) if clip is not None: key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) lora = comfy.lora_convert.convert_lora(lora) loaded = comfy.lora.load_lora(lora, key_map) #print(temp_object_patches_backup) if model is not None: new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) else: k = () new_modelpatcher = None if clip is not None: new_clip = clip.clone() k1 = new_clip.add_patches(loaded, strength_clip) else: k1 = () new_clip = None k = set(k) k1 = set(k1) for x in loaded: if (x not in k) and (x not in k1): logging.warning("NOT LOADED {}".format(x)) if patch_keys: if hasattr(model.model, "compile_settings"): compile_settings = getattr(model.model, 
"compile_settings") logging.info("compile_settings: ", compile_settings) for k in patch_keys: if "diffusion_model." in k: # Remove the prefix to get the attribute path key = k.replace('diffusion_model.', '') attributes = key.split('.') # Start with the diffusion_model object block = model.get_model_object("diffusion_model") # Navigate through the attributes to get to the block for attr in attributes: if attr.isdigit(): block = block[int(attr)] else: block = getattr(block, attr) # Compile the block compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"]) # Add the compiled block back as an object patch model.add_object_patch(k, compiled_block) return (new_modelpatcher, new_clip) class PatchModelPatcherOrder: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}), "full_load": (["enabled", "disabled", "auto"], {"default": "auto", "tooltip": "Disabling may help with memory issues when loading large models, when changing this you should probably force model reload to avoid issues!"}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/experimental" DESCRIPTION = "NO LONGER NECESSARY, keeping node for backwards compatibility. Use the v2 compile nodes to use LoRA with torch.compile." DEPRECATED = True def patch(self, model, patch_order, full_load): comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {} setattr(model.model, "full_load_override", full_load) if patch_order == "weight_patch_first": comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model comfy.sd.load_lora_for_models = patched_load_lora_for_models else: comfy.model_patcher.ModelPatcher.patch_model = _original_functions.get("original_patch_model") comfy.sd.load_lora_for_models = _original_functions.get("original_load_lora_for_models") return model, class TorchCompileModelFluxAdvanced: def __init__(self): self._compiled = False @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "backend": (["inductor", "cudagraphs"],), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "double_blocks": ("STRING", {"default": "0-18", "multiline": True}), "single_blocks": ("STRING", {"default": "0-37", "multiline": True}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), }, "optional": { "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), } } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True DEPRECATED = True def parse_blocks(self, blocks_str): blocks = [] for part in blocks_str.split(','): part = part.strip() if '-' in part: start, end = map(int, part.split('-')) blocks.extend(range(start, end + 1)) else: blocks.append(int(part)) return blocks def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit): single_block_list = self.parse_blocks(single_blocks) double_block_list = self.parse_blocks(double_blocks) m = model.clone() diffusion_model = m.get_model_object("diffusion_model") 
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit if not self._compiled: try: for i, block in enumerate(diffusion_model.double_blocks): if i in double_block_list: #print("Compiling double_block", i) m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)) for i, block in enumerate(diffusion_model.single_blocks): if i in single_block_list: #print("Compiling single block", i) m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)) self._compiled = True compile_settings = { "backend": backend, "mode": mode, "fullgraph": fullgraph, "dynamic": dynamic, } setattr(m.model, "compile_settings", compile_settings) except: raise RuntimeError("Failed to compile model") return (m, ) # rest of the layers that are not patched # diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend) class TorchCompileModelFluxAdvancedV2: def __init__(self): self._compiled = False @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "backend": (["inductor", "cudagraphs"],), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}), "single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), }, "optional": { "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}), } } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit=64, force_parameter_static_shapes=True): from comfy_api.torch_helpers import set_torch_compile_wrapper m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit torch._dynamo.config.force_parameter_static_shapes = force_parameter_static_shapes compile_key_list = [] try: if double_blocks: for i, block in enumerate(diffusion_model.double_blocks): print("Adding double block to compile list", i) compile_key_list.append(f"diffusion_model.double_blocks.{i}") if single_blocks: for i, block in enumerate(diffusion_model.single_blocks): compile_key_list.append(f"diffusion_model.single_blocks.{i}") set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, 
fullgraph=fullgraph) except: raise RuntimeError("Failed to compile model") return (m, ) # rest of the layers that are not patched # diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend) class TorchCompileModelHyVideo: def __init__(self): self._compiled = False @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}), "compile_double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}), "compile_txt_in": ("BOOLEAN", {"default": False, "tooltip": "Compile txt_in layers"}), "compile_vector_in": ("BOOLEAN", {"default": False, "tooltip": "Compile vector_in layers"}), "compile_final_layer": ("BOOLEAN", {"default": False, "tooltip": "Compile final layer"}), }, } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" DEPRECATED = True CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks, compile_txt_in, compile_vector_in, compile_final_layer): m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit if not self._compiled: try: if compile_single_blocks: for i, block in enumerate(diffusion_model.single_blocks): compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch(f"diffusion_model.single_blocks.{i}", compiled_block) if compile_double_blocks: for i, block in enumerate(diffusion_model.double_blocks): compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch(f"diffusion_model.double_blocks.{i}", compiled_block) if compile_txt_in: compiled_block = torch.compile(diffusion_model.txt_in, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch("diffusion_model.txt_in", compiled_block) if compile_vector_in: compiled_block = torch.compile(diffusion_model.vector_in, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch("diffusion_model.vector_in", compiled_block) if compile_final_layer: compiled_block = torch.compile(diffusion_model.final_layer, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch("diffusion_model.final_layer", compiled_block) 
self._compiled = True compile_settings = { "backend": backend, "mode": mode, "fullgraph": fullgraph, "dynamic": dynamic, } setattr(m.model, "compile_settings", compile_settings) except: raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileModelWanVideo: def __init__(self): self._compiled = False @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "compile_transformer_blocks_only": ("BOOLEAN", {"default": False, "tooltip": "Compile only transformer blocks"}), }, } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True DEPRECATED = True def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only): m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit try: if compile_transformer_blocks_only: for i, block in enumerate(diffusion_model.blocks): if hasattr(block, "_orig_mod"): block = block._orig_mod compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block) else: compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch("diffusion_model", compiled_model) compile_settings = { "backend": backend, "mode": mode, "fullgraph": fullgraph, "dynamic": dynamic, } setattr(m.model, "compile_settings", compile_settings) except: raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileModelWanVideoV2: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), }, "optional": { "force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}), }, } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, force_parameter_static_shapes=True): from comfy_api.torch_helpers import set_torch_compile_wrapper m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit torch._dynamo.config.force_parameter_static_shapes = force_parameter_static_shapes try: if compile_transformer_blocks_only: 
compile_key_list = [] for i, block in enumerate(diffusion_model.blocks): compile_key_list.append(f"diffusion_model.blocks.{i}") else: compile_key_list =["diffusion_model"] set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph) except: raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileModelAdvanced: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}), }, } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" DESCRIPTION = "Advanced torch.compile patching for diffusion models." EXPERIMENTAL = True def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, debug_compile_keys): from comfy_api.torch_helpers import set_torch_compile_wrapper m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit try: if compile_transformer_blocks_only: layer_types = ["double_blocks", "single_blocks", "layers", "transformer_blocks", "blocks", "visual_transformer_blocks", "text_transformer_blocks"] compile_key_list = [] for layer_name in layer_types: if hasattr(diffusion_model, layer_name): blocks = getattr(diffusion_model, layer_name) for i in range(len(blocks)): compile_key_list.append(f"diffusion_model.{layer_name}.{i}") if not compile_key_list: logging.warning("No known transformer blocks found to compile, compiling entire diffusion model instead") elif debug_compile_keys: logging.info("TorchCompileModelAdvanced: Compile key list:") for key in compile_key_list: logging.info(f" - {key}") if not compile_key_list: compile_key_list =["diffusion_model"] set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph) except: raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileModelQwenImage: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), }, } RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, fullgraph, 
mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only): from comfy_api.torch_helpers import set_torch_compile_wrapper m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit try: if compile_transformer_blocks_only: compile_key_list = [] for i, block in enumerate(diffusion_model.transformer_blocks): compile_key_list.append(f"diffusion_model.transformer_blocks.{i}") else: compile_key_list =["diffusion_model"] set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph) except: raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileVAE: def __init__(self): self._compiled_encoder = False self._compiled_decoder = False @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("VAE",), "backend": (["inductor", "cudagraphs"],), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "compile_encoder": ("BOOLEAN", {"default": True, "tooltip": "Compile encoder"}), "compile_decoder": ("BOOLEAN", {"default": True, "tooltip": "Compile decoder"}), }} RETURN_TYPES = ("VAE",) FUNCTION = "compile" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder): if compile_encoder: if not self._compiled_encoder: encoder_name = "encoder" if hasattr(vae.first_stage_model, "taesd_encoder"): encoder_name = "taesd_encoder" try: setattr( vae.first_stage_model, encoder_name, torch.compile( getattr(vae.first_stage_model, encoder_name), mode=mode, fullgraph=fullgraph, backend=backend, ), ) self._compiled_encoder = True except: raise RuntimeError("Failed to compile model") if compile_decoder: if not self._compiled_decoder: decoder_name = "decoder" if hasattr(vae.first_stage_model, "taesd_decoder"): decoder_name = "taesd_decoder" try: setattr( vae.first_stage_model, decoder_name, torch.compile( getattr(vae.first_stage_model, decoder_name), mode=mode, fullgraph=fullgraph, backend=backend, ), ) self._compiled_decoder = True except: raise RuntimeError("Failed to compile model") return (vae, ) class TorchCompileControlNet: def __init__(self): self._compiled= False @classmethod def INPUT_TYPES(s): return {"required": { "controlnet": ("CONTROL_NET",), "backend": (["inductor", "cudagraphs"],), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), }} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "compile" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def compile(self, controlnet, backend, mode, fullgraph): if not self._compiled: try: # for i, block in enumerate(controlnet.control_model.double_blocks): # print("Compiling controlnet double_block", i) # controlnet.control_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend) controlnet.control_model = torch.compile(controlnet.control_model, mode=mode, fullgraph=fullgraph, backend=backend) self._compiled = True except: self._compiled = False raise RuntimeError("Failed to compile model") return (controlnet, ) class TorchCompileLTXModel: def __init__(self): self._compiled = False @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "backend": (["inductor", "cudagraphs"],), 
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, mode, fullgraph, dynamic): m = model.clone() diffusion_model = m.get_model_object("diffusion_model") if not self._compiled: try: for i, block in enumerate(diffusion_model.transformer_blocks): compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend) m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block) self._compiled = True compile_settings = { "backend": backend, "mode": mode, "fullgraph": fullgraph, "dynamic": dynamic, } setattr(m.model, "compile_settings", compile_settings) except: raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileCosmosModel: def __init__(self): self._compiled = False @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "backend": (["inductor", "cudagraphs"],), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "tooltip": "Set the dynamo cache size limit"}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit): m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit if not self._compiled: try: for name, block in diffusion_model.blocks.items(): #print(f"Compiling block {name}") compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend) m.add_object_patch(f"diffusion_model.blocks.{name}", compiled_block) #diffusion_model.blocks[name] = compiled_block self._compiled = True compile_settings = { "backend": backend, "mode": mode, "fullgraph": fullgraph, "dynamic": dynamic, } setattr(m.model, "compile_settings", compile_settings) except: raise RuntimeError("Failed to compile model") return (m, ) #teacache try: from comfy.ldm.wan.model import sinusoidal_embedding_1d except: pass from einops import repeat from unittest.mock import patch from contextlib import nullcontext import numpy as np def relative_l1_distance(last_tensor, current_tensor): l1_distance = torch.abs(last_tensor - current_tensor).mean() norm = torch.abs(last_tensor).mean() relative_l1_distance = l1_distance / norm return relative_l1_distance.to(torch.float32) @torch.compiler.disable() def tea_cache(self, x, e0, e, transformer_options): #teacache for cond and uncond separately rel_l1_thresh = transformer_options["rel_l1_thresh"] is_cond = True if transformer_options["cond_or_uncond"] == [0] else False should_calc = True suffix = "cond" if is_cond else "uncond" # Init cache dict if not exists if not hasattr(self, 'teacache_state'): self.teacache_state = { 'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None, 'teacache_skipped_steps': 0, 'previous_residual': None}, 'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None, 'teacache_skipped_steps': 0, 
'previous_residual': None} } logging.info("\nTeaCache: Initialized") cache = self.teacache_state[suffix] if cache['prev_input'] is not None: if transformer_options["coefficients"] == []: temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0) curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1 else: rescale_func = np.poly1d(transformer_options["coefficients"]) curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item()) try: if curr_acc_dist < rel_l1_thresh: should_calc = False cache['accumulated_rel_l1_distance'] = curr_acc_dist else: should_calc = True cache['accumulated_rel_l1_distance'] = 0 except: should_calc = True cache['accumulated_rel_l1_distance'] = 0 if transformer_options["coefficients"] == []: cache['prev_input'] = e0.clone().detach() else: cache['prev_input'] = e.clone().detach() if not should_calc: x += cache['previous_residual'].to(x.device) cache['teacache_skipped_steps'] += 1 #print(f"TeaCache: Skipping {suffix} step") return should_calc, cache def teacache_wanvideo_vace_forward_orig(self, x, t, context, vace_context, vace_strength, clip_fea=None, freqs=None, transformer_options={}, **kwargs): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) # time embeddings e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context = self.text_embedding(context) context_img_len = None if clip_fea is not None: if self.img_emb is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) context_img_len = clip_fea.shape[-2] orig_shape = list(vace_context.shape) vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:]) c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) c = c.flatten(2).transpose(1, 2) c = list(c.split(orig_shape[0], dim=0)) if not transformer_options: raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later") teacache_enabled = transformer_options.get("teacache_enabled", False) if not teacache_enabled: should_calc = True else: should_calc, cache = tea_cache(self, x, e0, e, transformer_options) if should_calc: original_x = x.clone().detach() patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) ii = self.vace_layers_mapping.get(i, None) if ii is not None: for iii in range(len(c)): c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=original_x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) x += c_skip * vace_strength[iii] del c_skip if teacache_enabled: cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"]) # head x = self.head(x, e) # unpatchify x = self.unpatchify(x, grid_sizes) return x def 
teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) # time embeddings e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context = self.text_embedding(context) context_img_len = None if clip_fea is not None: if self.img_emb is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) context_img_len = clip_fea.shape[-2] teacache_enabled = transformer_options.get("teacache_enabled", False) if not teacache_enabled: should_calc = True else: should_calc, cache = tea_cache(self, x, e0, e, transformer_options) if should_calc: original_x = x.clone().detach() patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) if teacache_enabled: cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"]) # head x = self.head(x, e) # unpatchify x = self.unpatchify(x, grid_sizes) return x class WanVideoTeaCacheKJ: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "rel_l1_thresh": ("FLOAT", {"default": 0.275, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold for to determine when to apply the cache, compromise between speed and accuracy. When using coefficients a good value range is something between 0.2-0.4 for all but 1.3B model, which should be about 10 times smaller, same as when not using coefficients."}), "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}), "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), "coefficients": (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"], {"default": "i2v_480", "tooltip": "Coefficients for rescaling the relative l1 distance, if disabled the threshold value should be about 10 times smaller than the value used with coefficients."}), } } RETURN_TYPES = ("MODEL",) RETURN_NAMES = ("model",) FUNCTION = "patch_teacache" CATEGORY = "KJNodes/teacache" DEPRECATED = True DESCRIPTION = """ Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and applying it instead of doing the step. Best results are achieved by choosing the appropriate coefficients for the model. Early steps should never be skipped, with too aggressive values this can happen and the motion suffers. Starting later can help with that too. When NOT using coefficients, the threshold value should be about 10 times smaller than the value used with coefficients. 
Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4Wan2.1:
+-------------------+--------+---------+--------+
|       Model       |  Low   | Medium  |  High  |
+-------------------+--------+---------+--------+
| Wan2.1 t2v 1.3B   |  0.05  |  0.07   |  0.08  |
| Wan2.1 t2v 14B    |  0.14  |  0.15   |  0.20  |
| Wan2.1 i2v 480P   |  0.13  |  0.19   |  0.26  |
| Wan2.1 i2v 720P   |  0.18  |  0.20   |  0.30  |
+-------------------+--------+---------+--------+
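
As a rough worked example of the scaling described above: if the node default threshold of 0.275 works well with the i2v_480 coefficients, a comparable run with coefficients disabled would use a threshold about 10 times smaller, around 0.03.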
""" EXPERIMENTAL = True def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients): if rel_l1_thresh == 0: return (model,) if coefficients == "disabled" and rel_l1_thresh > 0.1: logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.") if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients: logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.") # type_str = str(type(model.model.model_config).__name__) #if model.model.diffusion_model.dim == 1536: # model_type ="1.3B" # else: # if "WAN21_T2V" in type_str: # model_type = "14B" # elif "WAN21_I2V" in type_str: # model_type = "i2v_480" # else: # model_type = "i2v_720" #how to detect this? teacache_coefficients_map = { "disabled": [], "1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], "14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], "i2v_480": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], "i2v_720": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], } coefficients = teacache_coefficients_map[coefficients] teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device() model_clone = model.clone() if 'transformer_options' not in model_clone.model_options: model_clone.model_options['transformer_options'] = {} model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device model_clone.model_options["transformer_options"]["coefficients"] = coefficients diffusion_model = model_clone.get_model_object("diffusion_model") def outer_wrapper(start_percent, end_percent): def unet_wrapper_function(model_function, kwargs): input = kwargs["input"] timestep = kwargs["timestep"] c = kwargs["c"] sigmas = c["transformer_options"]["sample_sigmas"] cond_or_uncond = kwargs["cond_or_uncond"] last_step = (len(sigmas) - 1) matched_step_index = (sigmas == timestep[0] ).nonzero() if len(matched_step_index) > 0: current_step_index = matched_step_index.item() else: for i in range(len(sigmas) - 1): # walk from beginning of steps until crossing the timestep if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0: current_step_index = i break else: current_step_index = 0 if current_step_index == 0: if (len(cond_or_uncond) == 1 and cond_or_uncond[0] == 1) or len(cond_or_uncond) == 2: if hasattr(diffusion_model, "teacache_state"): delattr(diffusion_model, "teacache_state") logging.info("\nResetting TeaCache state") current_percent = current_step_index / (len(sigmas) - 1) c["transformer_options"]["current_percent"] = current_percent if start_percent <= current_percent <= end_percent: c["transformer_options"]["teacache_enabled"] = True forward_function = teacache_wanvideo_vace_forward_orig if hasattr(diffusion_model, "vace_layers") else teacache_wanvideo_forward_orig context = patch.multiple( diffusion_model, forward_orig=forward_function.__get__(diffusion_model, diffusion_model.__class__) ) with context: out = model_function(input, timestep, **c) if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"): if len(cond_or_uncond) == 1 and cond_or_uncond[0] == 0: skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"] 
skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"] logging.info("-----------------------------------") logging.info(f"TeaCache skipped:") logging.info(f"{skipped_steps_cond} cond steps") logging.info(f"{skipped_steps_uncond} uncond step") logging.info(f"out of {last_step} steps") logging.info("-----------------------------------") elif len(cond_or_uncond) == 2: skipped_steps_cond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"] logging.info("-----------------------------------") logging.info(f"TeaCache skipped:") logging.info(f"{skipped_steps_cond} cond steps") logging.info(f"out of {last_step} steps") logging.info("-----------------------------------") return out return unet_wrapper_function model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent)) return (model_clone,) from comfy.ldm.flux.math import apply_rope def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim # query, key, value function def qkv_fn(x): q = self.norm_q(self.q(x)).view(b, s, n, d) k = self.norm_k(self.k(x)).view(b, s, n, d) v = self.v(x).view(b, s, n * d) return q, k, v q, k, v = qkv_fn(x) q, k = apply_rope(q, k, freqs) feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight) try: x = comfy.ldm.modules.attention.optimized_attention( q.view(b, s, n * d), k.view(b, s, n * d), v, heads=self.num_heads, transformer_options=transformer_options, ) except: # backward compatibility for now x = comfy.ldm.modules.attention.attention( q.view(b, s, n * d), k.view(b, s, n * d), v, heads=self.num_heads, ) x = self.o(x) x *= feta_scores return x from einops import rearrange def get_feta_scores(query, key, num_frames, enhance_weight): img_q, img_k = query, key #torch.Size([2, 9216, 12, 128]) _, ST, num_heads, head_dim = img_q.shape spatial_dim = ST / num_frames spatial_dim = int(spatial_dim) query_image = rearrange( img_q, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim ) key_image = rearrange( img_k, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim ) return feta_score(query_image, key_image, head_dim, num_frames, enhance_weight) def feta_score(query_image, key_image, head_dim, num_frames, enhance_weight): scale = head_dim**-0.5 query_image = query_image * scale attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32 attn_temp = attn_temp.to(torch.float32) attn_temp = attn_temp.softmax(dim=-1) # Reshape to [batch_size * num_tokens, num_frames, num_frames] attn_temp = attn_temp.reshape(-1, num_frames, num_frames) # Create a mask for diagonal elements diag_mask = torch.eye(num_frames, device=attn_temp.device).bool() diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1) # Zero out diagonal elements attn_wo_diag = attn_temp.masked_fill(diag_mask, 0) # Calculate mean for each token's attention matrix # Number of off-diagonal elements per matrix is n*n - n num_off_diag = num_frames * num_frames - num_frames mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag enhance_scores = mean_scores.mean() * (num_frames + enhance_weight) enhance_scores = enhance_scores.clamp(min=1) return enhance_scores import types class WanAttentionPatch: def __init__(self, num_frames, weight): self.num_frames = 
import types

class WanAttentionPatch:
    def __init__(self, num_frames, weight):
        self.num_frames = num_frames
        self.enhance_weight = weight

    def __get__(self, obj, objtype=None):
        # Create bound method with stored parameters
        def wrapped_attention(self_module, *args, **kwargs):
            self_module.num_frames = self.num_frames
            self_module.enhance_weight = self.enhance_weight
            return modified_wan_self_attention_forward(self_module, *args, **kwargs)
        return types.MethodType(wrapped_attention, obj)

class WanVideoEnhanceAVideoKJ:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
                "weight": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of the enhance effect"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "enhance"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
    EXPERIMENTAL = True

    def enhance(self, model, weight, latent):
        if weight == 0:
            return (model,)

        num_frames = latent["samples"].shape[2]

        model_clone = model.clone()
        if 'transformer_options' not in model_clone.model_options:
            model_clone.model_options['transformer_options'] = {}
        model_clone.model_options["transformer_options"]["enhance_weight"] = weight
        diffusion_model = model_clone.get_model_object("diffusion_model")

        compile_settings = getattr(model.model, "compile_settings", None)
        for idx, block in enumerate(diffusion_model.blocks):
            patched_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
            if compile_settings is not None:
                patched_attn = torch.compile(patched_attn, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])

            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)

        return (model_clone,)

def normalized_attention_guidance(self, query, context_positive, context_negative, transformer_options={}):
    k_positive = self.norm_k(self.k(context_positive))
    v_positive = self.v(context_positive)
    k_negative = self.norm_k(self.k(context_negative))
    v_negative = self.v(context_negative)

    try:
        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
    except:  # backwards compatibility for now
        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)

    nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)

    norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
    norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)

    scale = torch.nan_to_num(norm_guidance / norm_positive, nan=10.0)

    mask = scale > self.nag_tau
    adjustment = (norm_positive * self.nag_tau) / (norm_guidance + 1e-7)
    nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)

    x = nag_guidance * self.nag_alpha + x_positive * (1 - self.nag_alpha)
    del nag_guidance
    return x
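# Illustrative sketch (not used by any node in this file): the rescaling performed by
# normalized_attention_guidance above, reduced to plain tensors. Extrapolate away from
# the negative branch, clip the L1-norm ratio at tau, then blend back toward the
# positive branch with alpha. Argument names here are made up.
def _example_nag_blend(x_pos, x_neg, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
    guidance = x_pos * nag_scale - x_neg * (nag_scale - 1)
    norm_pos = torch.norm(x_pos, p=1, dim=-1, keepdim=True).expand_as(x_pos)
    norm_g = torch.norm(guidance, p=1, dim=-1, keepdim=True).expand_as(guidance)
    ratio = torch.nan_to_num(norm_g / norm_pos, nan=10.0)
    # where the guided features grew more than tau times larger (in L1), scale them back down
    guidance = torch.where(ratio > nag_tau, guidance * (norm_pos * nag_tau) / (norm_g + 1e-7), guidance)
    return guidance * nag_alpha + x_pos * (1 - nag_alpha)
# Example: _example_nag_blend(torch.randn(1, 16, 128), torch.randn(1, 16, 128)).shape -> (1, 16, 128)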
handling if self.input_type == "default": # Single or [pos, neg] pair if context.shape[0] == 1: x_pos, context_pos = x, context x_neg, context_neg = None, None else: x_pos, x_neg = torch.chunk(x, 2, dim=0) context_pos, context_neg = torch.chunk(context, 2, dim=0) elif self.input_type == "batch": # Standard batch, no CFG x_pos, context_pos = x, context x_neg, context_neg = None, None # Positive branch q_pos = self.norm_q(self.q(x_pos)) nag_context = self.nag_context if self.input_type == "batch": nag_context = nag_context.repeat(x_pos.shape[0], 1, 1) try: x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context, transformer_options=transformer_options) except: #backwards compatibility for now x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context) # Negative branch if x_neg is not None and context_neg is not None: q_neg = self.norm_q(self.q(x_neg)) k_neg = self.norm_k(self.k(context_neg)) v_neg = self.v(context_neg) try: x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options) except: #backwards compatibility for now x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads) x = torch.cat([x_pos_out, x_neg_out], dim=0) else: x = x_pos_out return self.o(x) def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ context_img = context[:, :context_img_len] context = context[:, context_img_len:] q_img = self.norm_q(self.q(x)) k_img = self.norm_k_img(self.k_img(context_img)) v_img = self.v_img(context_img) try: img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options) except: #backwards compatibility for now img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads) if context.shape[0] == 2: x, x_real_negative = torch.chunk(x, 2, dim=0) context_positive, context_negative = torch.chunk(context, 2, dim=0) else: context_positive = context context_negative = None q = self.norm_q(self.q(x)) x = normalized_attention_guidance(self, q, context_positive, self.nag_context, transformer_options=transformer_options) if context_negative is not None: q_real_negative = self.norm_q(self.q(x_real_negative)) k_real_negative = self.norm_k(self.k(context_negative)) v_real_negative = self.v(context_negative) try: x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads, transformer_options=transformer_options) except: #backwards compatibility for now x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads) x = torch.cat([x, x_real_negative], dim=0) # output x = x + img_x x = self.o(x) return x class WanCrossAttentionPatch: def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False, input_type="default"): self.nag_context = context self.nag_scale = nag_scale self.nag_alpha = nag_alpha self.nag_tau = nag_tau self.i2v = i2v self.input_type = input_type def __get__(self, obj, objtype=None): # Create bound method with stored parameters def wrapped_attention(self_module, *args, **kwargs): self_module.nag_context = self.nag_context self_module.nag_scale = self.nag_scale self_module.nag_alpha = self.nag_alpha self_module.nag_tau = 
class WanVideoNAG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "conditioning": ("CONDITIONING",),
                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of negative guidance effect"}),
                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient that controls the balance between the normalized guided representation and the original positive representation."}),
                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
            },
            "optional": {
                "input_type": (["default", "batch"], {"tooltip": "Type of the model input"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
    EXPERIMENTAL = True

    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau, input_type="default"):
        if nag_scale == 0:
            return (model,)

        device = mm.get_torch_device()
        dtype = mm.unet_dtype()

        model_clone = model.clone()
        diffusion_model = model_clone.get_model_object("diffusion_model")

        diffusion_model.text_embedding.to(device)
        context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))

        type_str = str(type(model.model.model_config).__name__)
        i2v = True if "WAN21_I2V" in type_str else False

        for idx, block in enumerate(diffusion_model.blocks):
            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v, input_type=input_type).__get__(block.cross_attn, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)

        return (model_clone,)

class SkipLayerGuidanceWanVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL", ),
                    "blocks": ("STRING", {"default": "10", "multiline": False}),
                    "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
                    "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                    }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "slg"
    EXPERIMENTAL = True
    DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
    DEPRECATED = True
    CATEGORY = "advanced/guidance"

    def slg(self, model, start_percent, end_percent, blocks):
        def skip(args, extra_args):
            transformer_options = extra_args.get("transformer_options", {})
            original_block = extra_args["original_block"]
            if not transformer_options:
                raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
            if start_percent <= transformer_options["current_percent"] <= end_percent:
                if args["img"].shape[0] == 2:
                    prev_img_uncond = args["img"][0].unsqueeze(0)

                    new_args = {
                        "img": args["img"][1].unsqueeze(0),
                        "txt": args["txt"][1].unsqueeze(0),
                        "vec": args["vec"][1].unsqueeze(0),
                        "pe": args["pe"][1].unsqueeze(0)
                    }

                    block_out = original_block(new_args)

                    out = {
                        "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
                        "txt": args["txt"],
                        "vec": args["vec"],
                        "pe": args["pe"]
                    }
                else:
                    if transformer_options.get("cond_or_uncond") == [0]:
                        out = original_block(args)
                    else:
                        out = args
            else:
                out = original_block(args)
            return out

        block_list = [int(x.strip()) for x in blocks.split(",")]
        blocks = [int(i) for i in block_list]
        logging.info(f"Selected blocks to skip uncond on: {blocks}")

        m = model.clone()

        for b in blocks:
            # m.set_model_patch_replace(skip, "dit", "double_block", b)
            model_options = m.model_options["transformer_options"].copy()
            if "patches_replace" not in model_options:
                model_options["patches_replace"] = {}
            else:
                model_options["patches_replace"] = model_options["patches_replace"].copy()

            if "dit" not in model_options["patches_replace"]:
                model_options["patches_replace"]["dit"] = {}
            else:
                model_options["patches_replace"]["dit"] = model_options["patches_replace"]["dit"].copy()

            block = ("double_block", b)

            model_options["patches_replace"]["dit"][block] = skip
            m.model_options["transformer_options"] = model_options

        return (m, )
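# Illustrative sketch (not used by any node in this file): the nested dictionary layout
# that slg() above writes into, shown in isolation with a copy-on-write style so the
# original options dict is not mutated. The function name and the stub replacement
# callable are hypothetical.
def _example_register_dit_block_patch(model_options: dict, block_index: int, replacement) -> dict:
    transformer_options = dict(model_options.get("transformer_options", {}))
    patches_replace = dict(transformer_options.get("patches_replace", {}))
    dit_patches = dict(patches_replace.get("dit", {}))
    dit_patches[("double_block", block_index)] = replacement  # keyed by (block type, index)
    patches_replace["dit"] = dit_patches
    transformer_options["patches_replace"] = patches_replace
    out = dict(model_options)
    out["transformer_options"] = transformer_options
    return out
# Example: _example_register_dit_block_patch({"transformer_options": {}}, 10, lambda args, extra: args)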
class CFGZeroStarAndInit:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL",),
                    "use_zero_init": ("BOOLEAN", {"default": True}),
                    "zero_init_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
                    }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    DESCRIPTION = "https://github.com/WeichenFan/CFG-Zero-star"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    def patch(self, model, use_zero_init, zero_init_steps):
        def cfg_zerostar(args):
            # zero init
            cond = args["cond"]
            timestep = args["timestep"]
            sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
            matched_step_index = (sigmas == timestep[0]).nonzero()
            if len(matched_step_index) > 0:
                current_step_index = matched_step_index.item()
            else:
                for i in range(len(sigmas) - 1):
                    if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
                        current_step_index = i
                        break
                else:
                    current_step_index = 0

            if (current_step_index <= zero_init_steps) and use_zero_init:
                return cond * 0

            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            batch_size = cond.shape[0]

            positive_flat = cond.view(batch_size, -1)
            negative_flat = uncond.view(batch_size, -1)

            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
            alpha = dot_product / squared_norm
            alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))

            noise_pred = uncond * alpha + cond_scale * (cond - uncond * alpha)
            return noise_pred

        m = model.clone()
        m.set_model_sampler_cfg_function(cfg_zerostar)
        return (m, )
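# Illustrative sketch (not used by any node in this file): the projection used by
# cfg_zerostar above on plain tensors. Per batch item, alpha = <cond, uncond> / ||uncond||^2,
# then regular CFG is applied against the rescaled uncond. Names are illustrative.
def _example_cfg_zero_star(cond: torch.Tensor, uncond: torch.Tensor, cond_scale: float = 6.0) -> torch.Tensor:
    b = cond.shape[0]
    pos = cond.reshape(b, -1)
    neg = uncond.reshape(b, -1)
    alpha = (pos * neg).sum(dim=1, keepdim=True) / (neg.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    alpha = alpha.view(b, *([1] * (cond.dim() - 1)))
    return uncond * alpha + cond_scale * (cond - uncond * alpha)
# Example: _example_cfg_zero_star(torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)).shape -> (2, 4, 8, 8)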
if v3_available:
    class GGUFLoaderKJ(io.ComfyNode):
        @classmethod
        def define_schema(cls):
            # Get GGUF models safely, fallback to empty list if unet_gguf folder doesn't exist
            try:
                gguf_models = folder_paths.get_filename_list("unet_gguf")
            except KeyError:
                gguf_models = []
            return io.Schema(
                node_id="GGUFLoaderKJ",
                category="KJNodes/experimental",
                description="Loads a GGUF model with advanced options, requires [ComfyUI-GGUF](https://github.com/city96/ComfyUI-GGUF) to be installed.",
                is_experimental=True,
                inputs=[
                    io.Combo.Input("model_name", options=gguf_models),
                    io.Combo.Input("extra_model_name", options=gguf_models + ["none"], default="none", tooltip="An extra gguf model to load and merge into the main model, for example VACE module"),
                    io.Combo.Input("dequant_dtype", options=["default", "target", "float32", "float16", "bfloat16"], default="default"),
                    io.Combo.Input("patch_dtype", options=["default", "target", "float32", "float16", "bfloat16"], default="default"),
                    io.Boolean.Input("patch_on_device", default=False),
                    io.Boolean.Input("enable_fp16_accumulation", default=False, tooltip="Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires PyTorch 2.7.1 or newer"),
                    io.Combo.Input("attention_override", options=["none", "sdpa", "sageattn", "xformers", "flashattn"], default="none", tooltip="Overrides the used attention implementation, requires the respective library to be installed"),
                ],
                outputs=[io.Model.Output(),],
            )

        def attention_override_pytorch(func, *args, **kwargs):
            new_attention = comfy.ldm.modules.attention.attention_pytorch
            return new_attention.__wrapped__(*args, **kwargs)

        def attention_override_sage(func, *args, **kwargs):
            new_attention = comfy.ldm.modules.attention.attention_sage
            return new_attention.__wrapped__(*args, **kwargs)

        def attention_override_xformers(func, *args, **kwargs):
            new_attention = comfy.ldm.modules.attention.attention_xformers
            return new_attention.__wrapped__(*args, **kwargs)

        def attention_override_flash(func, *args, **kwargs):
            new_attention = comfy.ldm.modules.attention.attention_flash
            return new_attention.__wrapped__(*args, **kwargs)

        ATTENTION_OVERRIDES = {
            "sdpa": attention_override_pytorch,
            "sageattn": attention_override_sage,
            "xformers": attention_override_xformers,
            "flashattn": attention_override_flash,
        }

        @classmethod
        def _get_gguf_module(cls):
            """Import GGUF module with version validation"""
            gguf_path = os.path.join(folder_paths.folder_names_and_paths["custom_nodes"][0][0], "ComfyUI-GGUF")
            for module_name in ["ComfyUI-GGUF", "custom_nodes.ComfyUI-GGUF", "comfyui-gguf", "custom_nodes.comfyui-gguf", gguf_path, gguf_path.lower()]:
                try:
                    module = importlib.import_module(module_name)
                    return module
                except ImportError:
                    continue
            raise ImportError(
                "Compatible ComfyUI-GGUF not found. "
                "Please install/update from: https://github.com/city96/ComfyUI-GGUF"
            )

        @classmethod
        def execute(cls, model_name, extra_model_name, dequant_dtype, patch_dtype, patch_on_device, attention_override, enable_fp16_accumulation):
            gguf_nodes = cls._get_gguf_module()

            ops = gguf_nodes.ops.GGMLOps()

            def set_linear_dtype(attr, value):
                if value == "default":
                    setattr(ops.Linear, attr, None)
                elif value == "target":
                    setattr(ops.Linear, attr, value)
                else:
                    setattr(ops.Linear, attr, getattr(torch, value))

            set_linear_dtype("dequant_dtype", dequant_dtype)
            set_linear_dtype("patch_dtype", patch_dtype)

            # init model
            model_path = folder_paths.get_full_path("unet", model_name)
            sd = gguf_nodes.loader.gguf_sd_loader(model_path)

            if extra_model_name is not None and extra_model_name != "none":
                if not extra_model_name.endswith(".gguf"):
                    raise ValueError("Extra model must also be a .gguf file")
                extra_model_full_path = folder_paths.get_full_path("unet", extra_model_name)
                extra_model = gguf_nodes.loader.gguf_sd_loader(extra_model_full_path)
                sd.update(extra_model)

            model = comfy.sd.load_diffusion_model_state_dict(
                sd, model_options={"custom_operations": ops}
            )
            if model is None:
                raise RuntimeError(f"ERROR: Could not detect model type of: {model_path}")
            model = gguf_nodes.nodes.GGUFModelPatcher.clone(model)
            model.patch_on_device = patch_on_device

            # attention override
            if attention_override in cls.ATTENTION_OVERRIDES:
                model.model_options["transformer_options"]["optimized_attention_override"] = cls.ATTENTION_OVERRIDES[attention_override]

            if enable_fp16_accumulation:
                if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                    torch.backends.cuda.matmul.allow_fp16_accumulation = True
                else:
                    raise RuntimeError("Failed to set fp16 accumulation, requires pytorch version 2.7.1 or higher")
            else:
                if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                    torch.backends.cuda.matmul.allow_fp16_accumulation = False

            return io.NodeOutput(model,)
else:
    class GGUFLoaderKJ:
        @classmethod
        def INPUT_TYPES(s):
            return {}
        RETURN_TYPES = ()
        FUNCTION = ""
        CATEGORY = ""
        DESCRIPTION = "This node requires newer ComfyUI"
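# Illustrative sketch (not used by any node in this file): the fp16-accumulation toggle
# used by GGUFLoaderKJ.execute above, as a standalone helper. The flag only exists on
# newer PyTorch builds, so it is feature-detected with hasattr rather than a version
# check. The helper name is made up.
def _example_set_fp16_accumulation(enabled: bool) -> bool:
    if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
        torch.backends.cuda.matmul.allow_fp16_accumulation = enabled
        return True
    return False  # PyTorch too old, nothing was changed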
try:
    from torch.nn.attention.flex_attention import flex_attention, BlockMask
except:
    flex_attention = None
    BlockMask = None

class NABLA_AttentionKJ():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL",),
                    "latent": ("LATENT", {"tooltip": "Only used to get the latent shape"}),
                    "window_time": ("INT", {"default": 11, "min": 1, "tooltip": "Temporal attention window size"}),
                    "window_width": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
                    "window_height": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
                    "sparsity": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "torch_compile": ("BOOLEAN", {"default": True, "tooltip": "Most likely required for reasonable memory usage"})
                    },
                }
    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching attention mode to use NABLA sparse attention for video models, currently only works with Kandinsky5"
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, latent, window_time, window_width, window_height, sparsity, torch_compile):
        if flex_attention is None or BlockMask is None:
            raise RuntimeError("can't import flex_attention from torch.nn.attention, requires newer pytorch version")
        model_clone = model.clone()
        samples = latent["samples"]

        sparse_params = get_sparse_params(samples, window_time, window_height, window_width, sparsity)

        nabla_attention = NABLA_Attention(sparse_params)
        def attention_override_nabla(func, *args, **kwargs):
            return nabla_attention(*args, **kwargs)
        if torch_compile:
            attention_override_nabla = torch.compile(attention_override_nabla, mode="max-autotune-no-cudagraphs", dynamic=True)

        # attention override
        model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_nabla

        return model_clone,

class NABLA_Attention():
    def __init__(self, sparse_params):
        self.sparse_params = sparse_params

    def __call__(self, q, k, v, heads, **kwargs):
        if q.shape[-2] < 3000 or k.shape[-2] < 3000:
            return optimized_attention(q, k, v, heads, **kwargs)
        block_mask = self.nablaT_v2(q, k, self.sparse_params["sta_mask"], thr=self.sparse_params["P"])
        out = flex_attention(q, k, v, block_mask=block_mask).transpose(1, 2).contiguous().flatten(-2, -1)
        return out

    def nablaT_v2(self, q, k, sta, thr=0.9):
        # Map estimation
        BLOCK_SIZE = 64
        B, h, S, D = q.shape
        s1 = S // BLOCK_SIZE
        qa = q.reshape(B, h, s1, BLOCK_SIZE, D).mean(-2)
        ka = k.reshape(B, h, s1, BLOCK_SIZE, D).mean(-2).transpose(-2, -1)
        map = qa @ ka
        map = torch.softmax(map / math.sqrt(D), dim=-1)

        # Map binarization
        vals, inds = map.sort(-1)
        cvals = vals.cumsum_(-1)
        mask = (cvals >= 1 - thr).int()
        mask = mask.gather(-1, inds.argsort(-1))
        mask = torch.logical_or(mask, sta)

        # BlockMask creation
        kv_nb = mask.sum(-1).to(torch.int32)
        kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)

        return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=BLOCK_SIZE, mask_mod=None)

def fast_sta_nabla(T, H, W, wT=3, wH=3, wW=3):
    l = torch.Tensor([T, H, W]).amax()
    r = torch.arange(0, l, 1, dtype=torch.int16, device=mm.get_torch_device())
    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    sta_t, sta_h, sta_w = (
        mat[:T, :T].flatten(),
        mat[:H, :H].flatten(),
        mat[:W, :W].flatten(),
    )
    sta_t = sta_t <= wT // 2
    sta_h = sta_h <= wH // 2
    sta_w = sta_w <= wW // 2
    sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
    sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
    return sta.reshape(T * H * W, T * H * W)
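# Illustrative sketch (not used by any node in this file): the "topcdf" binarization used
# by NABLA_Attention.nablaT_v2 above, on a toy score matrix. Softmax the block-level
# scores, sort ascending, and drop the smallest entries whose cumulative probability
# stays below 1 - thr. Input shape and the helper name are arbitrary examples.
def _example_topcdf_mask(scores: torch.Tensor, thr: float = 0.9) -> torch.Tensor:
    probs = torch.softmax(scores, dim=-1)
    vals, inds = probs.sort(-1)
    keep = (vals.cumsum(-1) >= 1 - thr).int()
    # scatter the kept/dropped flags back to the original column order
    return keep.gather(-1, inds.argsort(-1)).bool()
# Example: _example_topcdf_mask(torch.randn(1, 1, 4, 4)) -> boolean mask of kept key blocks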
def get_sparse_params(x, wT, wH, wW, sparsity=0.9):
    B, C, T, H, W = x.shape
    #print("x shape:", x.shape)
    patch_size = (1, 2, 2)
    T, H, W = (
        T // patch_size[0],
        H // patch_size[1],
        W // patch_size[2],
    )
    sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW)
    sparse_params = {
        "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
        "to_fractal": True,
        "P": sparsity,
        "wT": wT,
        "wH": wH,
        "wW": wW,
        "add_sta": True,
        "visual_shape": (T, H, W),
        "method": "topcdf",
    }
    return sparse_params

from comfy.comfy_types.node_typing import IO

class StartRecordCUDAMemoryHistory():
    # @classmethod
    # def IS_CHANGED(s):
    #     return True
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "input": (IO.ANY,),
                "enabled": (["all", "state", "None"], {"default": "all", "tooltip": "None: disable, 'state': keep info for allocated memory, 'all': keep history of all alloc/free calls"}),
                "context": (["all", "state", "alloc", "None"], {"default": "all", "tooltip": "None: no tracebacks, 'state': tracebacks for allocated memory, 'alloc': for alloc calls, 'all': for free calls"}),
                "stacks": (["python", "all"], {"default": "all", "tooltip": "'python': Python/TorchScript/inductor frames, 'all': also C++ frames"}),
                "max_entries": ("INT", {"default": 100000, "min": 1000, "max": 10000000, "tooltip": "Maximum number of entries to record"}),
            },
        }

    RETURN_TYPES = (IO.ANY, )
    RETURN_NAMES = ("input",)
    FUNCTION = "start"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "THIS NODE ALWAYS RUNS. Starts recording CUDA memory allocation history, can be ended and saved with EndRecordCUDAMemoryHistory."

    def start(self, input, enabled, context, stacks, max_entries):
        mm.soft_empty_cache()
        torch.cuda.reset_peak_memory_stats(mm.get_torch_device())
        torch.cuda.memory._record_memory_history(
            max_entries=max_entries,
            enabled=enabled if enabled != "None" else None,
            context=context if context != "None" else None,
            stacks=stacks
        )
        return input,
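# Illustrative sketch (not used by any node in this file): the same private torch APIs
# the Start/End nodes use, as one linear flow: record allocator history, run a workload,
# dump the snapshot, stop recording. Requires a CUDA build; the output file name and the
# matmul workload are arbitrary examples.
def _example_record_memory_history(output_path="comfy_cuda_memory_history_example.pt", max_entries=100000):
    if not torch.cuda.is_available():
        return None
    torch.cuda.memory._record_memory_history(max_entries=max_entries)
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    (a @ b).sum().item()  # some work to record
    torch.cuda.memory._dump_snapshot(output_path)
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
    return output_path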
" def start(self, input, enabled, context, stacks, max_entries): mm.soft_empty_cache() torch.cuda.reset_peak_memory_stats(mm.get_torch_device()) torch.cuda.memory._record_memory_history( max_entries=max_entries, enabled=enabled if enabled != "None" else None, context=context if context != "None" else None, stacks=stacks ) return input, class EndRecordCUDAMemoryHistory(): @classmethod def INPUT_TYPES(s): return {"required": { "input": (IO.ANY,), "output_path": ("STRING", {"default": "comfy_cuda_memory_history"}, "Base path for saving the CUDA memory history file, timestamp and .pt extension will be added"), }, } RETURN_TYPES = (IO.ANY, "STRING",) RETURN_NAMES = ("input", "output_path",) FUNCTION = "end" CATEGORY = "KJNodes/experimental" DESCRIPTION = "Records CUDA memory allocation history between start and end, saves to a file that can be analyzed here: https://docs.pytorch.org/memory_viz or with VisualizeCUDAMemoryHistory node" def end(self, input, output_path): mm.soft_empty_cache() time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") output_path = f"{output_path}{time}.pt" torch.cuda.memory._dump_snapshot(output_path) torch.cuda.memory._record_memory_history(enabled=None) return input, output_path try: from server import PromptServer except: PromptServer = None class VisualizeCUDAMemoryHistory(): @classmethod def INPUT_TYPES(s): return {"required": { "snapshot_path": ("STRING", ), }, "hidden": { "unique_id": "UNIQUE_ID", }, } RETURN_TYPES = ("STRING",) RETURN_NAMES = ("output_path",) FUNCTION = "visualize" CATEGORY = "KJNodes/experimental" DESCRIPTION = "Visualizes a CUDA memory allocation history file, opens in browser" OUTPUT_NODE = True def visualize(self, snapshot_path, unique_id): import pickle from torch.cuda import _memory_viz import uuid from folder_paths import get_output_directory output_dir = get_output_directory() with open(snapshot_path, "rb") as f: snapshot = pickle.load(f) html = _memory_viz.trace_plot(snapshot) html_filename = f"cuda_memory_history_{uuid.uuid4().hex}.html" output_path = os.path.join(output_dir, "memory_history", html_filename) os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: f.write(html) api_url = f"http://localhost:8188/api/view?type=output&filename={html_filename}&subfolder=memory_history" # Progress UI if unique_id and PromptServer is not None: try: PromptServer.instance.send_progress_text( api_url, unique_id ) except: pass return api_url,