From 729a6485ea617f578a1f324e21b2ccac42bd1ce8 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 1 Dec 2024 18:45:10 +0200
Subject: [PATCH] expose sageattn 2.0.0 functions

The _cuda variants seem to be required on RTX 30xx-series GPUs for sageattn with CogVideoX 1.5.
---
 custom_cogvideox_transformer_3d.py | 144 +++++++++++------------------
 model_loading.py                   |  52 +++++++++--
 2 files changed, 96 insertions(+), 100 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 7401cd9..4cb64a6 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -46,9 +46,40 @@ except:
 from comfy.ldm.modules.attention import optimized_attention
 
 
-@torch.compiler.disable()
-def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
-    return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
+
+def set_attention_func(attention_mode, heads):
+    if attention_mode == "sdpa" or attention_mode == "fused_sdpa":
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
+        return func
+    elif attention_mode == "comfy":
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True)
+        return func
+
+    elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+        return func
+    elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
+        from sageattention import sageattn_qk_int8_pv_fp16_cuda
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+        return func
+    elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
+        from sageattention import sageattn_qk_int8_pv_fp16_triton
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+        return func
+    elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
+        from sageattention import sageattn_qk_int8_pv_fp8_cuda
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
+        return func
 
 def fft(tensor):
     tensor_fft = torch.fft.fft2(tensor)
@@ -67,16 +98,18 @@ def fft(tensor):
     return low_freq_fft, high_freq_fft
 
 
+#region Attention
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
""" - def __init__(self): + def __init__(self, attn_func, attention_mode: Optional[str] = None): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - + self.attention_mode = attention_mode + self.attn_func = attn_func def __call__( self, attn: Attention, @@ -84,7 +117,6 @@ class CogVideoXAttnProcessor2_0: encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - attention_mode: Optional[str] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.size(1) @@ -101,7 +133,7 @@ class CogVideoXAttnProcessor2_0: if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16: hidden_states = hidden_states.to(attn.to_q.weight.dtype) - if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn": + if not "fused" in self.attention_mode: query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) @@ -128,16 +160,10 @@ class CogVideoXAttnProcessor2_0: if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - if attention_mode == "sageattn" or attention_mode == "fused_sageattn": - hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False) + hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False) + + if self.attention_mode != "comfy": hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - elif attention_mode == "sdpa" or attention_mode == "fused_sdpa": - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - elif attention_mode == "comfy": - hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True) # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -203,13 +229,15 @@ class CogVideoXBlock(nn.Module): ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, + attention_mode: Optional[str] = "sdpa", ): super().__init__() # 1. Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - + attn_func = set_attention_func(attention_mode, num_attention_heads) + self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, @@ -218,7 +246,7 @@ class CogVideoXBlock(nn.Module): eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), + processor=CogVideoXAttnProcessor2_0(attn_func, attention_mode=attention_mode), ) # 2. 
@@ -247,7 +275,6 @@ class CogVideoXBlock(nn.Module):
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
-        attention_mode="sdpa",
     ) -> torch.Tensor:
         #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
@@ -286,7 +313,6 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=attention_mode,
             )
             if fastercache_counter == fastercache_start_step:
                 self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -298,8 +324,7 @@ class CogVideoXBlock(nn.Module):
             attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
-                image_rotary_emb=image_rotary_emb,
-                attention_mode=attention_mode,
+                image_rotary_emb=image_rotary_emb
             )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -408,6 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
+        attention_mode: Optional[str] = "sdpa",
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -461,6 +487,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     dropout=dropout,
                     activation_fn=activation_fn,
                     attention_bias=attention_bias,
+                    attention_mode=attention_mode,
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                 )
@@ -496,73 +523,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
         self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
-        self.attention_mode = "sdpa"
+        self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
-
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-
+    #region forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -624,8 +590,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device,
-                    attention_mode = self.attention_mode
+                    fastercache_device = self.fastercache_device
                 )
 
             if (controlnet_states is not None) and (i < len(controlnet_states)):
@@ -695,8 +660,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device,
-                    attention_mode = self.attention_mode
+                    fastercache_device = self.fastercache_device
                 )
                 #has_nan = torch.isnan(hidden_states).any()
                 #if has_nan:
@@ -754,4 +718,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
-    
+    
\ No newline at end of file
diff --git a/model_loading.py b/model_loading.py
index 4245c64..436212e 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -70,6 +70,7 @@ class CogVideoLoraSelect:
     RETURN_NAMES = ("lora", )
     FUNCTION = "getlorapath"
     CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Select a LoRA model from ComfyUI/models/CogVideo/loras"
 
     def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
         cog_loras_list = []
@@ -93,7 +94,7 @@ class CogVideoLoraSelectComfy:
         return {
             "required": {
                 "lora": (folder_paths.get_filename_list("loras"),
-                {"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
+                {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
                 "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
             },
             "optional": {
@@ -106,6 +107,7 @@ class CogVideoLoraSelectComfy:
     RETURN_NAMES = ("lora", )
     FUNCTION = "getlorapath"
     CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"
 
     def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
         cog_loras_list = []
@@ -160,7 +162,19 @@ class DownloadAndLoadCogVideoModel:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
("COGLORA", {"default": None}), "compile_args":("COMPILEARGS", ), - "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}), + "attention_mode": ([ + "sdpa", + "fused_sdpa", + "sageattn", + "fused_sageattn", + "sageattn_qk_int8_pv_fp8_cuda", + "sageattn_qk_int8_pv_fp16_cuda", + "sageattn_qk_int8_pv_fp16_triton", + "fused_sageattn_qk_int8_pv_fp8_cuda", + "fused_sageattn_qk_int8_pv_fp16_cuda", + "fused_sageattn_qk_int8_pv_fp16_triton", + "comfy" + ], {"default": "sdpa"}), "load_device": (["main_device", "offload_device"], {"default": "main_device"}), } } @@ -175,11 +189,18 @@ class DownloadAndLoadCogVideoModel: enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None, attention_mode="sdpa", load_device="main_device"): + transformer = None + if "sage" in attention_mode: try: from sageattention import sageattn except Exception as e: raise ValueError(f"Can't import SageAttention: {str(e)}") + if "qk_int8" in attention_mode: + try: + from sageattention import sageattn_qk_int8_pv_fp16_cuda + except Exception as e: + raise ValueError(f"Can't import SageAttention 2.0.0: {str(e)}") if precision == "fp16" and "1.5" in model: raise ValueError("1.5 models do not currently work in fp16") @@ -254,7 +275,7 @@ class DownloadAndLoadCogVideoModel: local_dir_use_symlinks=False, ) - transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder) + transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode) transformer = transformer.to(dtype).to(transformer_load_device) if "1.5" in model: @@ -327,7 +348,6 @@ class DownloadAndLoadCogVideoModel: for module in pipe.transformer.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) - pipe.transformer.attention_mode = attention_mode if compile_args is not None: pipe.transformer.to(memory_format=torch.channels_last) @@ -551,7 +571,7 @@ class DownloadAndLoadCogVideoGGUFModel: else: transformer_config["in_channels"] = 16 - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode) cast_dtype = vae_dtype params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"} if "2b" in model: @@ -655,7 +675,19 @@ class CogVideoXModelLoader: "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), "lora": ("COGLORA", {"default": None}), "compile_args":("COMPILEARGS", ), - "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}), + "attention_mode": ([ + "sdpa", + "fused_sdpa", + "sageattn", + "fused_sageattn", + "sageattn_qk_int8_pv_fp8_cuda", + "sageattn_qk_int8_pv_fp16_cuda", + "sageattn_qk_int8_pv_fp16_triton", + "fused_sageattn_qk_int8_pv_fp8_cuda", + "fused_sageattn_qk_int8_pv_fp16_cuda", + "fused_sageattn_qk_int8_pv_fp16_triton", + "comfy" + ], {"default": "sdpa"}), } } @@ -666,7 +698,7 @@ class CogVideoXModelLoader: def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload, block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"): - + transformer = None if "sage" in attention_mode: try: from sageattention import sageattn @@ -732,7 +764,7 @@ class CogVideoXModelLoader: transformer_config["sample_width"] = 300 with init_empty_weights(): - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + transformer = 
+            transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
 
         #load weights
         #params_to_keep = {}
@@ -1084,7 +1116,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoLoraSelect": CogVideoLoraSelect,
     "CogVideoXVAELoader": CogVideoXVAELoader,
     "CogVideoXModelLoader": CogVideoXModelLoader,
-    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
+    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1094,5 +1126,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoLoraSelect": "CogVideo LoraSelect",
     "CogVideoXVAELoader": "CogVideoX VAE Loader",
     "CogVideoXModelLoader": "CogVideoX Model Loader",
-    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
+    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
 }
\ No newline at end of file
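
Note (not part of the patch): a minimal usage sketch of the attention_mode values this patch exposes, assuming SageAttention 2.0.0 is installed; the checkpoint path and subfolder below are placeholders, and the from_pretrained call mirrors the one added in model_loading.py above.

# Minimal sketch, not part of the patch: prefer a SageAttention 2.0.0 per-kernel
# mode when its entry point is importable, otherwise fall back to the generic
# "sageattn" mode. Path and subfolder are placeholder values.
from custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel

try:
    from sageattention import sageattn_qk_int8_pv_fp16_cuda  # noqa: F401  (requires SageAttention >= 2.0.0)
    attention_mode = "sageattn_qk_int8_pv_fp16_cuda"  # the _cuda variant reported to work on RTX 30xx GPUs
except ImportError:
    attention_mode = "sageattn"

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "models/CogVideo/CogVideoX-1.5-5b",  # placeholder checkpoint path
    subfolder="transformer",             # placeholder subfolder
    attention_mode=attention_mode,
)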