Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git, synced 2026-03-16 14:27:05 +08:00.
Start using core "optimized_attention_override" for sageattn patches
This commit is contained in: parent e64b67b8f4 · commit 1585f9b523
@ -1,15 +1,17 @@
|
||||
import os
|
||||
from comfy.ldm.modules import attention as comfy_attention
|
||||
import logging
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import torch
|
||||
import importlib
|
||||
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
from comfy.cli_args import args
|
||||
from typing import Optional, Tuple
|
||||
import importlib
|
||||
from comfy.ldm.modules.attention import wrap_attn
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
|
||||
try:
|
||||
from comfy_api.latest import io
|
||||
v3_available = True
|
||||
@ -32,161 +34,83 @@ if not _initialized:
|
||||
pass
|
||||
_initialized = True
|
||||
|
||||
|
||||
def get_sage_func(sage_attention, allow_compile=False):
|
||||
logging.info(f"Using sage attention mode: {sage_attention}")
|
||||
from sageattention import sageattn
|
||||
if sage_attention == "auto":
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
|
||||
from sageattention import sageattn_qk_int8_pv_fp16_cuda
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
|
||||
from sageattention import sageattn_qk_int8_pv_fp16_triton
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
|
||||
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
|
||||
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
|
||||
elif "sageattn3" in sage_attention:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
if sage_attention == "sageattn3_per_block_mean":
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
|
||||
else:
|
||||
def sage_func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
|
||||
logging.info(f"Sage attention function: {sage_func}")
|
||||
|
||||
if not allow_compile:
|
||||
sage_func = torch.compiler.disable()(sage_func)
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout="NHD"
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
out = out.transpose(1, 2)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
return attention_sage
|
||||
|
||||
class BaseLoaderKJ:
|
||||
original_linear = None
|
||||
cublas_patched = False
|
||||
|
||||
@torch.compiler.disable()
|
||||
def _patch_modules(self, patch_cublaslinear, sage_attention):
|
||||
try:
|
||||
from comfy.ldm.qwen_image.model import apply_rotary_emb
|
||||
def qwen_sage_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor, # Image stream
|
||||
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
|
||||
img_attn_output = self.to_out[0](img_attn_output)
|
||||
img_attn_output = self.to_out[1](img_attn_output)
|
||||
txt_attn_output = self.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
except:
|
||||
print("Failed to patch QwenImage attention, Comfy not updated, skipping")
|
||||
|
||||
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
|
||||
|
||||
if sage_attention != "disabled":
|
||||
print("Patching comfy attention to use sageattn")
|
||||
from sageattention import sageattn
|
||||
def set_sage_func(sage_attention):
|
||||
if sage_attention == "auto":
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
|
||||
return func
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
|
||||
from sageattention import sageattn_qk_int8_pv_fp16_cuda
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
|
||||
return func
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
|
||||
from sageattention import sageattn_qk_int8_pv_fp16_triton
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
|
||||
return func
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
|
||||
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
|
||||
return func
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
|
||||
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
|
||||
return func
|
||||
elif "sageattn3" in sage_attention:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
if sage_attention == "sageattn3_per_block_mean":
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
|
||||
else:
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
|
||||
return func
|
||||
|
||||
sage_func = set_sage_func(sage_attention)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, transformer_options=None):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout="NHD"
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
out = out.transpose(1, 2)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
comfy_attention.optimized_attention = attention_sage
|
||||
comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
|
||||
comfy.ldm.flux.math.optimized_attention = attention_sage
|
||||
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
|
||||
comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
|
||||
comfy.ldm.wan.model.optimized_attention = attention_sage
|
||||
try:
|
||||
comfy.ldm.qwen_image.model.Attention.forward = qwen_sage_forward
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
print("Restoring initial comfy attention")
|
||||
comfy_attention.optimized_attention = _original_functions.get("orig_attention")
|
||||
comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
|
||||
comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
|
||||
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
|
||||
comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
|
||||
comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
|
||||
try:
|
||||
comfy.ldm.qwen_image.model.Attention.forward = _original_functions.get("original_qwen_forward")
|
||||
except:
|
||||
pass
|
||||
|
||||
if patch_cublaslinear:
|
||||
if not BaseLoaderKJ.cublas_patched:
|
||||
BaseLoaderKJ.original_linear = disable_weight_init.Linear
|
||||
@ -218,13 +142,17 @@ class BaseLoaderKJ:
|
||||
|
||||
|
||||
from comfy.patcher_extension import CallbacksMP
|
||||
class PathchSageAttentionKJ(BaseLoaderKJ):
|
||||
class PathchSageAttentionKJ():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL",),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
|
||||
}}
|
||||
},
|
||||
"optional": {
|
||||
"allow_compile": ("BOOLEAN", {"default": False, "tooltip": "Allow the use of torch.compile for the new attention function."})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", )
|
||||
FUNCTION = "patch"
|
||||
@ -232,18 +160,19 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, model, sage_attention):
|
||||
def patch(self, model, sage_attention, allow_compile=False):
|
||||
if sage_attention == "disabled":
|
||||
return model,
|
||||
|
||||
model_clone = model.clone()
|
||||
@torch.compiler.disable()
|
||||
def patch_attention_enable(model):
|
||||
self._patch_modules(False, sage_attention)
|
||||
@torch.compiler.disable()
|
||||
def patch_attention_disable(model):
|
||||
self._patch_modules(False, "disabled")
|
||||
|
||||
model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable)
|
||||
model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable)
|
||||
|
||||
|
||||
new_attention = get_sage_func(sage_attention, allow_compile=allow_compile)
|
||||
def attention_override_sage(func, *args, **kwargs):
|
||||
return new_attention.__wrapped__(*args, **kwargs)
|
||||
|
||||
# attention override
|
||||
model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
|
||||
|
||||
return model_clone,
|
||||
|
||||
class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
@ -306,9 +235,14 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = False
|
||||
|
||||
def patch_attention(model):
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
|
||||
if sage_attention != "disabled":
|
||||
new_attention = get_sage_func(sage_attention)
|
||||
def attention_override_sage(func, *args, **kwargs):
|
||||
return new_attention.__wrapped__(*args, **kwargs)
|
||||
|
||||
# attention override
|
||||
model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
|
||||
|
||||
return model, clip, vae
|
||||
|
||||
def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
@ -468,9 +402,13 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
model.force_cast_weights = False
|
||||
print(f"Setting {model_name} compute dtype to {dtype}")
|
||||
|
||||
def patch_attention(model):
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
|
||||
if sage_attention != "disabled":
|
||||
new_attention = get_sage_func(sage_attention)
|
||||
def attention_override_sage(func, *args, **kwargs):
|
||||
return new_attention.__wrapped__(*args, **kwargs)
|
||||
|
||||
# attention override
|
||||
model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
Loading… — Reference in new issue · Block a user