diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 68fbd55..d30eff4 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1,15 +1,17 @@
 import os
 from comfy.ldm.modules import attention as comfy_attention
 import logging
-import comfy.model_patcher
-import comfy.utils
-import comfy.sd
 import torch
+import importlib
+
 import folder_paths
 import comfy.model_management as mm
 from comfy.cli_args import args
-from typing import Optional, Tuple
-import importlib
+from comfy.ldm.modules.attention import wrap_attn
+import comfy.model_patcher
+import comfy.utils
+import comfy.sd
+
 try:
     from comfy_api.latest import io
     v3_available = True
@@ -32,161 +34,83 @@
         pass
     _initialized = True
 
+
+def get_sage_func(sage_attention, allow_compile=False):
+    logging.info(f"Using sage attention mode: {sage_attention}")
+    from sageattention import sageattn
+    if sage_attention == "auto":
+        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
+    elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
+        from sageattention import sageattn_qk_int8_pv_fp16_cuda
+        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
+    elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
+        from sageattention import sageattn_qk_int8_pv_fp16_triton
+        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
+    elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
+        from sageattention import sageattn_qk_int8_pv_fp8_cuda
+        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
+    elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
+        from sageattention import sageattn_qk_int8_pv_fp8_cuda
+        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
+    elif "sageattn3" in sage_attention:
+        from sageattn3 import sageattn3_blackwell
+        if sage_attention == "sageattn3_per_block_mean":
+            def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
+                return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
+        else:
+            def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
+                return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
+    logging.info(f"Sage attention function: {sage_func}")
+
+    if not allow_compile:
+        sage_func = torch.compiler.disable()(sage_func)
+
+    @wrap_attn
+    def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+        if skip_reshape:
+            b, _, _, dim_head = q.shape
+            tensor_layout="HND"
+        else:
+            b, _, dim_head = q.shape
+            dim_head //= heads
+            q, k, v = map(
+                lambda t: t.view(b, -1, heads, dim_head),
+                (q, k, v),
+            )
+            tensor_layout="NHD"
+        if mask is not None:
+            # add a batch dimension if there isn't already one
+            if mask.ndim == 2:
+                mask = mask.unsqueeze(0)
+            # add a heads dimension if there isn't already one
+            if mask.ndim == 3:
+                mask = mask.unsqueeze(1)
+        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        if tensor_layout == "HND":
+            if not skip_output_reshape:
+                out = (
+                    out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+                )
+        else:
+            if skip_output_reshape:
+                out = out.transpose(1, 2)
+            else:
+                out = out.reshape(b, -1, heads * dim_head)
+        return out
+    return attention_sage
+
+
 class BaseLoaderKJ:
     original_linear = None
     cublas_patched = False
 
-    @torch.compiler.disable()
     def _patch_modules(self, patch_cublaslinear, sage_attention):
-        try:
-            from comfy.ldm.qwen_image.model import apply_rotary_emb
-            def qwen_sage_forward(
-                self,
-                hidden_states: torch.FloatTensor,  # Image stream
-                encoder_hidden_states: torch.FloatTensor = None,  # Text stream
-                encoder_hidden_states_mask: torch.FloatTensor = None,
-                attention_mask: Optional[torch.FloatTensor] = None,
-                image_rotary_emb: Optional[torch.Tensor] = None,
-                transformer_options={},
-            ) -> Tuple[torch.Tensor, torch.Tensor]:
-                seq_txt = encoder_hidden_states.shape[1]
-
-                img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
-                img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
-                img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
-
-                txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
-                txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
-                txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
-
-                img_query = self.norm_q(img_query)
-                img_key = self.norm_k(img_key)
-                txt_query = self.norm_added_q(txt_query)
-                txt_key = self.norm_added_k(txt_key)
-
-                joint_query = torch.cat([txt_query, img_query], dim=1)
-                joint_key = torch.cat([txt_key, img_key], dim=1)
-                joint_value = torch.cat([txt_value, img_value], dim=1)
-
-                joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
-                joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
-
-                joint_query = joint_query.flatten(start_dim=2)
-                joint_key = joint_key.flatten(start_dim=2)
-                joint_value = joint_value.flatten(start_dim=2)
-
-                joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
-
-                txt_attn_output = joint_hidden_states[:, :seq_txt, :]
-                img_attn_output = joint_hidden_states[:, seq_txt:, :]
-
-                img_attn_output = self.to_out[0](img_attn_output)
-                img_attn_output = self.to_out[1](img_attn_output)
-                txt_attn_output = self.to_add_out(txt_attn_output)
-
-                return img_attn_output, txt_attn_output
-        except:
-            print("Failed to patch QwenImage attention, Comfy not updated, skipping")
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
-        if sage_attention != "disabled":
-            print("Patching comfy attention to use sageattn")
-            from sageattention import sageattn
-            def set_sage_func(sage_attention):
-                if sage_attention == "auto":
-                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
-                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
-                    return func
-                elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
-                    from sageattention import sageattn_qk_int8_pv_fp16_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
-                        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
-                    return func
-                elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
-                    from sageattention import sageattn_qk_int8_pv_fp16_triton
-                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
-                        return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
-                    return func
-                elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
-                    from sageattention import sageattn_qk_int8_pv_fp8_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
-                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
-                    return func
-                elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
-                    from sageattention import sageattn_qk_int8_pv_fp8_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
-                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
-                    return func
-                elif "sageattn3" in sage_attention:
-                    from sageattn3 import sageattn3_blackwell
-                    if sage_attention == "sageattn3_per_block_mean":
-                        def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
-                            return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
-                    else:
-                        def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
-                            return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
-                    return func
-
-            sage_func = set_sage_func(sage_attention)
-
-            @torch.compiler.disable()
-            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, transformer_options=None):
-                if skip_reshape:
-                    b, _, _, dim_head = q.shape
-                    tensor_layout="HND"
-                else:
-                    b, _, dim_head = q.shape
-                    dim_head //= heads
-                    q, k, v = map(
-                        lambda t: t.view(b, -1, heads, dim_head),
-                        (q, k, v),
-                    )
-                    tensor_layout="NHD"
-                if mask is not None:
-                    # add a batch dimension if there isn't already one
-                    if mask.ndim == 2:
-                        mask = mask.unsqueeze(0)
-                    # add a heads dimension if there isn't already one
-                    if mask.ndim == 3:
-                        mask = mask.unsqueeze(1)
-                out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
-                if tensor_layout == "HND":
-                    if not skip_output_reshape:
-                        out = (
-                            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-                        )
-                else:
-                    if skip_output_reshape:
-                        out = out.transpose(1, 2)
-                    else:
-                        out = out.reshape(b, -1, heads * dim_head)
-                return out
-
-            comfy_attention.optimized_attention = attention_sage
-            comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
-            comfy.ldm.flux.math.optimized_attention = attention_sage
-            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
-            comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
-            comfy.ldm.wan.model.optimized_attention = attention_sage
-            try:
-                comfy.ldm.qwen_image.model.Attention.forward = qwen_sage_forward
-            except:
-                pass
-
-        else:
-            print("Restoring initial comfy attention")
-            comfy_attention.optimized_attention = _original_functions.get("orig_attention")
-            comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
-            comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
-            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
-            comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
-            comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
-            try:
-                comfy.ldm.qwen_image.model.Attention.forward = _original_functions.get("original_qwen_forward")
-            except:
-                pass
-
         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
                 BaseLoaderKJ.original_linear = disable_weight_init.Linear
@@ -218,13 +142,17 @@ class BaseLoaderKJ:
 
 from comfy.patcher_extension import CallbacksMP
 
-class PathchSageAttentionKJ(BaseLoaderKJ):
+class PathchSageAttentionKJ():
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "model": ("MODEL",),
             "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
-            }}
+            },
+            "optional": {
+                "allow_compile": ("BOOLEAN", {"default": False, "tooltip": "Allow the use of torch.compile for the new attention function."})
+            }
+        }
 
     RETURN_TYPES = ("MODEL", )
     FUNCTION = "patch"
@@ -232,18 +160,19 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"
 
-    def patch(self, model, sage_attention):
+    def patch(self, model, sage_attention, allow_compile=False):
+        if sage_attention == "disabled":
+            return model,
+
         model_clone = model.clone()
-        @torch.compiler.disable()
-        def patch_attention_enable(model):
-            self._patch_modules(False, sage_attention)
-        @torch.compiler.disable()
-        def patch_attention_disable(model):
-            self._patch_modules(False, "disabled")
-
-        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable)
-        model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable)
-
+
+        new_attention = get_sage_func(sage_attention, allow_compile=allow_compile)
+        def attention_override_sage(func, *args, **kwargs):
+            return new_attention.__wrapped__(*args, **kwargs)
+
+        # attention override
+        model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
+
         return model_clone,
 
 class CheckpointLoaderKJ(BaseLoaderKJ):
@@ -306,9 +235,14 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
         if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
             torch.backends.cuda.matmul.allow_fp16_accumulation = False
 
-        def patch_attention(model):
-            self._patch_modules(patch_cublaslinear, sage_attention)
-        model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
+        if sage_attention != "disabled":
+            new_attention = get_sage_func(sage_attention)
+            def attention_override_sage(func, *args, **kwargs):
+                return new_attention.__wrapped__(*args, **kwargs)
+
+            # attention override
+            model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
+
         return model, clip, vae
 
     def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@@ -468,9 +402,13 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             model.force_cast_weights = False
         print(f"Setting {model_name} compute dtype to {dtype}")
 
-        def patch_attention(model):
-            self._patch_modules(patch_cublaslinear, sage_attention)
-        model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
+        if sage_attention != "disabled":
+            new_attention = get_sage_func(sage_attention)
+            def attention_override_sage(func, *args, **kwargs):
+                return new_attention.__wrapped__(*args, **kwargs)
+
+            # attention override
+            model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
 
         return (model,)
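
How the override is consumed on the ComfyUI side (a minimal sketch, not the actual implementation; it assumes `wrap_attn` from `comfy/ldm/modules/attention.py` dispatches through `transformer_options`, which is what lets this diff drop the old per-module monkey-patching of `optimized_attention`):

```python
# Sketch only: assumed shape of ComfyUI's wrap_attn decorator. The real code
# lives in comfy/ldm/modules/attention.py; the details here are illustrative.
import functools

def wrap_attn(func):
    @functools.wraps(func)  # exposes the undecorated function as .__wrapped__
    def wrapper(*args, **kwargs):
        transformer_options = kwargs.get("transformer_options", {})
        override = transformer_options.get("optimized_attention_override")
        if override is not None:
            # The override receives the wrapped attention first so that it
            # could fall back to it; otherwise the attention runs as normal.
            return override(func, *args, **kwargs)
        return func(*args, **kwargs)
    return wrapper
```

Note that `attention_override_sage` calls `new_attention.__wrapped__` rather than `new_attention` itself: the undecorated function bypasses the dispatch above, so installing the override cannot recurse back into it.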
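
For reference, the two `tensor_layout` values that `attention_sage` negotiates follow SageAttention's convention: "NHD" is (batch, seq_len, heads, head_dim) and "HND" is (batch, heads, seq_len, head_dim). A quick illustration with hypothetical shapes:

```python
# Illustrative shapes only. skip_reshape=False hands attention a flat
# (batch, seq, heads*dim) tensor which is viewed into NHD; skip_reshape=True
# means q/k/v already arrive as (batch, heads, seq, dim), i.e. HND.
import torch

b, s, h, d = 2, 1024, 8, 64
q_flat = torch.randn(b, s, h * d)       # input when skip_reshape=False
q_nhd = q_flat.view(b, -1, h, d)        # NHD: (batch, seq, heads, head_dim)
q_hnd = torch.randn(b, h, s, d)         # HND: (batch, heads, seq, head_dim)
# The layouts differ by a transpose of dims 1 and 2, which is also why the
# sageattn3 branch transposes before and after the kernel call.
assert q_hnd.transpose(1, 2).shape == q_nhd.shape
```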