Start using core "optimized_attention_override" for sageattn patches

This commit is contained in:
kijai 2025-11-05 14:05:23 +02:00
parent e64b67b8f4
commit 1585f9b523

View File

@ -1,15 +1,17 @@
import os
from comfy.ldm.modules import attention as comfy_attention
import logging
import comfy.model_patcher
import comfy.utils
import comfy.sd
import torch
import importlib
import folder_paths
import comfy.model_management as mm
from comfy.cli_args import args
from typing import Optional, Tuple
import importlib
from comfy.ldm.modules.attention import wrap_attn
import comfy.model_patcher
import comfy.utils
import comfy.sd
try:
from comfy_api.latest import io
v3_available = True
@ -32,161 +34,83 @@ if not _initialized:
pass
_initialized = True
def get_sage_func(sage_attention, allow_compile=False):
logging.info(f"Using sage attention mode: {sage_attention}")
from sageattention import sageattn
if sage_attention == "auto":
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
from sageattention import sageattn_qk_int8_pv_fp16_cuda
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
from sageattention import sageattn_qk_int8_pv_fp16_triton
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
from sageattention import sageattn_qk_int8_pv_fp8_cuda
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
from sageattention import sageattn_qk_int8_pv_fp8_cuda
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
elif "sageattn3" in sage_attention:
from sageattn3 import sageattn3_blackwell
if sage_attention == "sageattn3_per_block_mean":
def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
else:
def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
logging.info(f"Sage attention function: {sage_func}")
if not allow_compile:
sage_func = torch.compiler.disable()(sage_func)
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
return attention_sage
class BaseLoaderKJ:
original_linear = None
cublas_patched = False
@torch.compiler.disable()
def _patch_modules(self, patch_cublaslinear, sage_attention):
try:
from comfy.ldm.qwen_image.model import apply_rotary_emb
def qwen_sage_forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
img_attn_output = self.to_out[0](img_attn_output)
img_attn_output = self.to_out[1](img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
except:
print("Failed to patch QwenImage attention, Comfy not updated, skipping")
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
if sage_attention != "disabled":
print("Patching comfy attention to use sageattn")
from sageattention import sageattn
def set_sage_func(sage_attention):
if sage_attention == "auto":
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
return func
elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
from sageattention import sageattn_qk_int8_pv_fp16_cuda
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
return func
elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
from sageattention import sageattn_qk_int8_pv_fp16_triton
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
return func
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
from sageattention import sageattn_qk_int8_pv_fp8_cuda
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
return func
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
from sageattention import sageattn_qk_int8_pv_fp8_cuda
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
return func
elif "sageattn3" in sage_attention:
from sageattn3 import sageattn3_blackwell
if sage_attention == "sageattn3_per_block_mean":
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
else:
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
return func
sage_func = set_sage_func(sage_attention)
@torch.compiler.disable()
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, transformer_options=None):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
comfy_attention.optimized_attention = attention_sage
comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
comfy.ldm.flux.math.optimized_attention = attention_sage
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
comfy.ldm.wan.model.optimized_attention = attention_sage
try:
comfy.ldm.qwen_image.model.Attention.forward = qwen_sage_forward
except:
pass
else:
print("Restoring initial comfy attention")
comfy_attention.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
try:
comfy.ldm.qwen_image.model.Attention.forward = _original_functions.get("original_qwen_forward")
except:
pass
if patch_cublaslinear:
if not BaseLoaderKJ.cublas_patched:
BaseLoaderKJ.original_linear = disable_weight_init.Linear
@ -218,13 +142,17 @@ class BaseLoaderKJ:
from comfy.patcher_extension import CallbacksMP
class PathchSageAttentionKJ(BaseLoaderKJ):
class PathchSageAttentionKJ():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL",),
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
}}
},
"optional": {
"allow_compile": ("BOOLEAN", {"default": False, "tooltip": "Allow the use of torch.compile for the new attention function."})
}
}
RETURN_TYPES = ("MODEL", )
FUNCTION = "patch"
@ -232,18 +160,19 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
EXPERIMENTAL = True
CATEGORY = "KJNodes/experimental"
def patch(self, model, sage_attention):
def patch(self, model, sage_attention, allow_compile=False):
if sage_attention == "disabled":
return model,
model_clone = model.clone()
@torch.compiler.disable()
def patch_attention_enable(model):
self._patch_modules(False, sage_attention)
@torch.compiler.disable()
def patch_attention_disable(model):
self._patch_modules(False, "disabled")
model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable)
model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable)
new_attention = get_sage_func(sage_attention, allow_compile=allow_compile)
def attention_override_sage(func, *args, **kwargs):
return new_attention.__wrapped__(*args, **kwargs)
# attention override
model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
return model_clone,
class CheckpointLoaderKJ(BaseLoaderKJ):
@ -306,9 +235,14 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
torch.backends.cuda.matmul.allow_fp16_accumulation = False
def patch_attention(model):
self._patch_modules(patch_cublaslinear, sage_attention)
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
if sage_attention != "disabled":
new_attention = get_sage_func(sage_attention)
def attention_override_sage(func, *args, **kwargs):
return new_attention.__wrapped__(*args, **kwargs)
# attention override
model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
return model, clip, vae
def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@ -468,9 +402,13 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
model.force_cast_weights = False
print(f"Setting {model_name} compute dtype to {dtype}")
def patch_attention(model):
self._patch_modules(patch_cublaslinear, sage_attention)
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
if sage_attention != "disabled":
new_attention = get_sage_func(sage_attention)
def attention_override_sage(func, *args, **kwargs):
return new_attention.__wrapped__(*args, **kwargs)
# attention override
model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
return (model,)