Support Qwen image with sageattn patch

kijai 2025-08-05 23:52:54 +03:00
parent fbdb08f9d6
commit d382efd7e7


@@ -7,6 +7,8 @@ import torch
import folder_paths
import comfy.model_management as mm
from comfy.cli_args import args
from typing import Optional, Tuple

sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"]
@@ -17,6 +19,7 @@ if not _initialized:
    _original_functions["orig_attention"] = comfy_attention.optimized_attention
    _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
    _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
_original_functions["original_qwen_forward"] = comfy.ldm.qwen_image.model.Attention.forward
_initialized = True
class BaseLoaderKJ:
@@ -25,6 +28,55 @@ class BaseLoaderKJ:
    @torch.compiler.disable()
    def _patch_modules(self, patch_cublaslinear, sage_attention):
        try:
            from comfy.ldm.qwen_image.model import apply_rotary_emb
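            # Replacement for comfy.ldm.qwen_image.model.Attention.forward that routes the
            # joint text/image attention through SageAttention instead of the stock kernel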
            def qwen_sage_forward(
                self,
                hidden_states: torch.FloatTensor,  # Image stream
                encoder_hidden_states: torch.FloatTensor = None,  # Text stream
                encoder_hidden_states_mask: torch.FloatTensor = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                image_rotary_emb: Optional[torch.Tensor] = None,
            ) -> Tuple[torch.Tensor, torch.Tensor]:
                seq_txt = encoder_hidden_states.shape[1]
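                # Project each stream to per-head Q/K/V: (batch, seq, heads, dim_head)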
                img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
                img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
                img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
                txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
                txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
                txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
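                # QK normalization on both streams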
                img_query = self.norm_q(img_query)
                img_key = self.norm_k(img_key)
                txt_query = self.norm_added_q(txt_query)
                txt_key = self.norm_added_k(txt_key)
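                # Concatenate text and image tokens into a single joint sequence (text first)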
                joint_query = torch.cat([txt_query, img_query], dim=1)
                joint_key = torch.cat([txt_key, img_key], dim=1)
                joint_value = torch.cat([txt_value, img_value], dim=1)
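                # Rotary position embedding applied across the joint sequence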
                joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
                joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
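                # Flatten heads back to (batch, seq, heads * dim_head) for the attention call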
                joint_query = joint_query.flatten(start_dim=2)
                joint_key = joint_key.flatten(start_dim=2)
                joint_value = joint_value.flatten(start_dim=2)
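                # Run SageAttention over the joint sequence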
                joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask)
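                # Split the joint output back into its text and image parts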
                txt_attn_output = joint_hidden_states[:, :seq_txt, :]
                img_attn_output = joint_hidden_states[:, seq_txt:, :]
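                # Per-stream output projections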
                img_attn_output = self.to_out[0](img_attn_output)
                img_attn_output = self.to_out[1](img_attn_output)
                txt_attn_output = self.to_add_out(txt_attn_output)

                return img_attn_output, txt_attn_output
        except:
            print("Failed to patch QwenImage attention, Comfy not updated, skipping")
        from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight

        if sage_attention != "disabled":
@@ -97,6 +149,10 @@ class BaseLoaderKJ:
            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
            comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
            comfy.ldm.wan.model.optimized_attention = attention_sage
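            # Swap in the Sage-backed forward; skipped on older Comfy builds without the Qwen image model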
            try:
                comfy.ldm.qwen_image.model.Attention.forward = qwen_sage_forward
            except:
                pass
        else:
            print("Restoring initial comfy attention")
@@ -106,6 +162,10 @@ class BaseLoaderKJ:
            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
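            # Put the stock Qwen attention forward back from the saved reference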
            try:
                comfy.ldm.qwen_image.model.Attention.forward = _original_functions.get("original_qwen_forward")
            except:
                pass

        if patch_cublaslinear:
            if not BaseLoaderKJ.cublas_patched:
@@ -135,6 +195,7 @@ class BaseLoaderKJ:
            if BaseLoaderKJ.cublas_patched:
                disable_weight_init.Linear = BaseLoaderKJ.original_linear
                BaseLoaderKJ.cublas_patched = False

from comfy.patcher_extension import CallbacksMP

class PathchSageAttentionKJ(BaseLoaderKJ):