expose sageattn 2.0.0 functions

_cuda versions seem to be required on RTX 30xx -series GPUs for sageattn + CogVideoX 1.5
This commit is contained in:
kijai 2024-12-01 18:45:10 +02:00
parent 411791c748
commit 729a6485ea
2 changed files with 96 additions and 100 deletions

View File

@ -46,9 +46,40 @@ except:
from comfy.ldm.modules.attention import optimized_attention
@torch.compiler.disable()
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
def set_attention_func(attention_mode, heads):
if attention_mode == "sdpa" or attention_mode == "fused_sdpa":
def func(q, k, v, is_causal=False, attn_mask=None):
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
return func
elif attention_mode == "comfy":
def func(q, k, v, is_causal=False, attn_mask=None):
return optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True)
return func
elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
return func
elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
from sageattention import sageattn_qk_int8_pv_fp16_cuda
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
return func
elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
from sageattention import sageattn_qk_int8_pv_fp16_triton
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
return func
elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
from sageattention import sageattn_qk_int8_pv_fp8_cuda
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
return func
def fft(tensor):
tensor_fft = torch.fft.fft2(tensor)
@ -67,16 +98,18 @@ def fft(tensor):
return low_freq_fft, high_freq_fft
#region Attention
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
def __init__(self, attn_func, attention_mode: Optional[str] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.attention_mode = attention_mode
self.attn_func = attn_func
def __call__(
self,
attn: Attention,
@ -84,7 +117,6 @@ class CogVideoXAttnProcessor2_0:
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mode: Optional[str] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@ -101,7 +133,7 @@ class CogVideoXAttnProcessor2_0:
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
if not "fused" in self.attention_mode:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
@ -128,16 +160,10 @@ class CogVideoXAttnProcessor2_0:
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
if self.attention_mode != "comfy":
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif attention_mode == "sdpa" or attention_mode == "fused_sdpa":
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif attention_mode == "comfy":
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@ -203,13 +229,15 @@ class CogVideoXBlock(nn.Module):
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
attention_mode: Optional[str] = "sdpa",
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
attn_func = set_attention_func(attention_mode, num_attention_heads)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
@ -218,7 +246,7 @@ class CogVideoXBlock(nn.Module):
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
processor=CogVideoXAttnProcessor2_0(attn_func, attention_mode=attention_mode),
)
# 2. Feed Forward
@ -247,7 +275,6 @@ class CogVideoXBlock(nn.Module):
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
attention_mode="sdpa",
) -> torch.Tensor:
#print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
text_seq_length = encoder_hidden_states.size(1)
@ -286,7 +313,6 @@ class CogVideoXBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@ -298,8 +324,7 @@ class CogVideoXBlock(nn.Module):
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=attention_mode,
image_rotary_emb=image_rotary_emb
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
@ -408,6 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
attention_mode: Optional[str] = "sdpa",
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -461,6 +487,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
attention_mode=attention_mode,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
@ -496,73 +523,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
self.attention_mode = "sdpa"
self.attention_mode = attention_mode
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
#region forward
def forward(
self,
hidden_states: torch.Tensor,
@ -624,8 +590,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device,
attention_mode = self.attention_mode
fastercache_device = self.fastercache_device
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
@ -695,8 +660,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device,
attention_mode = self.attention_mode
fastercache_device = self.fastercache_device
)
#has_nan = torch.isnan(hidden_states).any()
#if has_nan:
@ -754,4 +718,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@ -70,6 +70,7 @@ class CogVideoLoraSelect:
RETURN_NAMES = ("lora", )
FUNCTION = "getlorapath"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Select a LoRA model from ComfyUI/models/CogVideo/loras"
def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
cog_loras_list = []
@ -93,7 +94,7 @@ class CogVideoLoraSelectComfy:
return {
"required": {
"lora": (folder_paths.get_filename_list("loras"),
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
{"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
@ -106,6 +107,7 @@ class CogVideoLoraSelectComfy:
RETURN_NAMES = ("lora", )
FUNCTION = "getlorapath"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"
def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
cog_loras_list = []
@ -160,7 +162,19 @@ class DownloadAndLoadCogVideoModel:
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
"attention_mode": ([
"sdpa",
"fused_sdpa",
"sageattn",
"fused_sageattn",
"sageattn_qk_int8_pv_fp8_cuda",
"sageattn_qk_int8_pv_fp16_cuda",
"sageattn_qk_int8_pv_fp16_triton",
"fused_sageattn_qk_int8_pv_fp8_cuda",
"fused_sageattn_qk_int8_pv_fp16_cuda",
"fused_sageattn_qk_int8_pv_fp16_triton",
"comfy"
], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}
}
@ -175,11 +189,18 @@ class DownloadAndLoadCogVideoModel:
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
transformer = None
if "sage" in attention_mode:
try:
from sageattention import sageattn
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")
if "qk_int8" in attention_mode:
try:
from sageattention import sageattn_qk_int8_pv_fp16_cuda
except Exception as e:
raise ValueError(f"Can't import SageAttention 2.0.0: {str(e)}")
if precision == "fp16" and "1.5" in model:
raise ValueError("1.5 models do not currently work in fp16")
@ -254,7 +275,7 @@ class DownloadAndLoadCogVideoModel:
local_dir_use_symlinks=False,
)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
transformer = transformer.to(dtype).to(transformer_load_device)
if "1.5" in model:
@ -327,7 +348,6 @@ class DownloadAndLoadCogVideoModel:
for module in pipe.transformer.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
pipe.transformer.attention_mode = attention_mode
if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last)
@ -551,7 +571,7 @@ class DownloadAndLoadCogVideoGGUFModel:
else:
transformer_config["in_channels"] = 16
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
cast_dtype = vae_dtype
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
if "2b" in model:
@ -655,7 +675,19 @@ class CogVideoXModelLoader:
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
"attention_mode": ([
"sdpa",
"fused_sdpa",
"sageattn",
"fused_sageattn",
"sageattn_qk_int8_pv_fp8_cuda",
"sageattn_qk_int8_pv_fp16_cuda",
"sageattn_qk_int8_pv_fp16_triton",
"fused_sageattn_qk_int8_pv_fp8_cuda",
"fused_sageattn_qk_int8_pv_fp16_cuda",
"fused_sageattn_qk_int8_pv_fp16_triton",
"comfy"
], {"default": "sdpa"}),
}
}
@ -666,7 +698,7 @@ class CogVideoXModelLoader:
def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
transformer = None
if "sage" in attention_mode:
try:
from sageattention import sageattn
@ -732,7 +764,7 @@ class CogVideoXModelLoader:
transformer_config["sample_width"] = 300
with init_empty_weights():
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
#load weights
#params_to_keep = {}
@ -1084,7 +1116,7 @@ NODE_CLASS_MAPPINGS = {
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoXVAELoader": CogVideoXVAELoader,
"CogVideoXModelLoader": CogVideoXModelLoader,
"CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
"CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -1094,5 +1126,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoXVAELoader": "CogVideoX VAE Loader",
"CogVideoXModelLoader": "CogVideoX Model Loader",
"CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
"CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
}