diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 2fb8f57..1f27778 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -7,6 +7,8 @@
 import torch
 import folder_paths
 import comfy.model_management as mm
 from comfy.cli_args import args
+from typing import Optional, Tuple
+
 sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"]
 
@@ -17,6 +19,7 @@ if not _initialized:
     _original_functions["orig_attention"] = comfy_attention.optimized_attention
     _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
     _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
+    _original_functions["original_qwen_forward"] = comfy.ldm.qwen_image.model.Attention.forward
     _initialized = True
 
 class BaseLoaderKJ:
@@ -25,6 +28,55 @@ class BaseLoaderKJ:
 
     @torch.compiler.disable()
     def _patch_modules(self, patch_cublaslinear, sage_attention):
+        try:
+            from comfy.ldm.qwen_image.model import apply_rotary_emb
+            def qwen_sage_forward(
+                self,
+                hidden_states: torch.FloatTensor,  # Image stream
+                encoder_hidden_states: torch.FloatTensor = None,  # Text stream
+                encoder_hidden_states_mask: torch.FloatTensor = None,
+                attention_mask: Optional[torch.FloatTensor] = None,
+                image_rotary_emb: Optional[torch.Tensor] = None,
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                seq_txt = encoder_hidden_states.shape[1]
+
+                img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
+                img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
+                img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
+
+                txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+                txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+                txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+
+                img_query = self.norm_q(img_query)
+                img_key = self.norm_k(img_key)
+                txt_query = self.norm_added_q(txt_query)
+                txt_key = self.norm_added_k(txt_key)
+
+                joint_query = torch.cat([txt_query, img_query], dim=1)
+                joint_key = torch.cat([txt_key, img_key], dim=1)
+                joint_value = torch.cat([txt_value, img_value], dim=1)
+
+                joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
+                joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+
+                joint_query = joint_query.flatten(start_dim=2)
+                joint_key = joint_key.flatten(start_dim=2)
+                joint_value = joint_value.flatten(start_dim=2)
+
+                joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask)
+
+                txt_attn_output = joint_hidden_states[:, :seq_txt, :]
+                img_attn_output = joint_hidden_states[:, seq_txt:, :]
+
+                img_attn_output = self.to_out[0](img_attn_output)
+                img_attn_output = self.to_out[1](img_attn_output)
+                txt_attn_output = self.to_add_out(txt_attn_output)
+
+                return img_attn_output, txt_attn_output
+        except:
+            print("Failed to patch QwenImage attention, Comfy not updated, skipping")
+
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
 
         if sage_attention != "disabled":
@@ -97,6 +149,10 @@ class BaseLoaderKJ:
             comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
             comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
             comfy.ldm.wan.model.optimized_attention = attention_sage
+            try:
+                comfy.ldm.qwen_image.model.Attention.forward = qwen_sage_forward
+            except:
+                pass
         else:
             print("Restoring initial comfy attention")
@@ -106,6 +162,10 @@ class BaseLoaderKJ:
             comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
             comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
             comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
+            try:
+                comfy.ldm.qwen_image.model.Attention.forward = _original_functions.get("original_qwen_forward")
+            except:
+                pass
 
         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
@@ -135,6 +195,7 @@ class BaseLoaderKJ:
             if BaseLoaderKJ.cublas_patched:
                 disable_weight_init.Linear = BaseLoaderKJ.original_linear
                 BaseLoaderKJ.cublas_patched = False
+
 from comfy.patcher_extension import CallbacksMP
 
 class PathchSageAttentionKJ(BaseLoaderKJ):