Compare commits


2 Commits

Author SHA1 Message Date

kijai
390d05fe7e Add generic TorchCompileModelAdvanced node to handle advanced compile options for all diffusion models
Avoids needing different nodes for different models
2025-11-27 13:59:31 +02:00

kijai
f0ed965cd9 Allow fp32 input for sageattn function
2025-11-27 13:33:41 +02:00
2 changed files with 59 additions and 2 deletions


@@ -211,6 +211,7 @@ NODE_CONFIG = {
"GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"}, "GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
"LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"}, "LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"},
"NABLA_AttentionKJ": {"class": NABLA_AttentionKJ, "name": "NABLA Attention KJ"}, "NABLA_AttentionKJ": {"class": NABLA_AttentionKJ, "name": "NABLA Attention KJ"},
"TorchCompileModelAdvanced": {"class": TorchCompileModelAdvanced, "name": "TorchCompileModelAdvanced"},
#instance diffusion #instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},


@@ -73,6 +73,9 @@ def get_sage_func(sage_attention, allow_compile=False):
    @wrap_attn
    def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
        in_dtype = v.dtype
        if q.dtype == torch.float32 or k.dtype == torch.float32 or v.dtype == torch.float32:
            q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
        if skip_reshape:
            b, _, _, dim_head = q.shape
            tensor_layout="HND"
@@ -91,7 +94,7 @@ def get_sage_func(sage_attention, allow_compile=False):
            # add a heads dimension if there isn't already one
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout).to(in_dtype)
        if tensor_layout == "HND":
            if not skip_output_reshape:
                out = (
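
For context (not part of the diff): SageAttention kernels only accept fp16/bf16 inputs, so the change above downcasts fp32 q/k/v before the kernel call and casts the output back to the caller's original dtype. A minimal standalone sketch of that pattern (function and argument names are made up for the example):

# Illustrative sketch only: cast-around-the-kernel pattern for fp32 inputs.
import torch

def fp32_safe_attention(q, k, v, attn_fn):
    in_dtype = v.dtype  # remember the caller's dtype
    if torch.float32 in (q.dtype, k.dtype, v.dtype):
        # the kernel only supports half precision, so downcast first
        q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
    out = attn_fn(q, k, v)   # e.g. a sageattn-style kernel
    return out.to(in_dtype)  # restore the original precision for the caller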
@@ -853,7 +856,60 @@ class TorchCompileModelWanVideoV2:
            raise RuntimeError("Failed to compile model")
        return (m, )
class TorchCompileModelAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "backend": (["inductor", "cudagraphs"], {"default": "inductor"}),
                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
                "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
                "debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/torchcompile"
    DESCRIPTION = "Advanced torch.compile patching for diffusion models."
    EXPERIMENTAL = True

    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, debug_compile_keys):
        from comfy_api.torch_helpers import set_torch_compile_wrapper
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
        try:
            compile_key_list = []
            if compile_transformer_blocks_only:
                layer_types = ["double_blocks", "single_blocks", "layers", "transformer_blocks", "blocks"]
                for layer_name in layer_types:
                    if hasattr(diffusion_model, layer_name):
                        blocks = getattr(diffusion_model, layer_name)
                        for i in range(len(blocks)):
                            compile_key_list.append(f"diffusion_model.{layer_name}.{i}")
                if not compile_key_list:
                    logging.warning("No known transformer blocks found to compile, compiling entire diffusion model instead")
                elif debug_compile_keys:
                    logging.info("TorchCompileModelAdvanced: Compile key list:")
                    for key in compile_key_list:
                        logging.info(f" - {key}")
            if not compile_key_list:
                compile_key_list = ["diffusion_model"]
            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
        except:
            raise RuntimeError("Failed to compile model")
        return (m, )
class TorchCompileModelQwenImage:
    @classmethod
    def INPUT_TYPES(s):
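
For context (not part of the diff): the new node clones the model, collects per-block compile keys such as diffusion_model.blocks.0 for whichever block attribute the model exposes, and hands them to set_torch_compile_wrapper so each transformer block is compiled separately; if no known block attribute is found it falls back to compiling the whole diffusion_model. A hypothetical direct call to the node outside the ComfyUI graph (loaded_model stands in for whatever MODEL a loader node returns):

# Hypothetical usage sketch: ComfyUI normally calls patch() through the node
# graph; loaded_model is a stand-in for a MODEL produced by a loader node.
node = TorchCompileModelAdvanced()
(compiled_model,) = node.patch(
    model=loaded_model,
    backend="inductor",
    fullgraph=False,
    mode="default",
    dynamic=False,
    dynamo_cache_size_limit=64,
    compile_transformer_blocks_only=True,  # compile diffusion_model.<blocks>.<i> individually
    debug_compile_keys=True,               # log the resolved compile keys
)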