From eebdc412f910cef06bcd43d84fe4a164d37e4e79 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 17 Nov 2024 21:43:53 +0200
Subject: [PATCH] fix sageattention

---
 custom_cogvideox_transformer_3d.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 20615be..ba9f037 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -121,10 +121,7 @@ class CogVideoXAttnProcessor2_0:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         if attention_mode == "sageattn":
-            if SAGEATTN_IS_AVAILABLE:
-                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
-            else:
-                raise ImportError("sageattn not found")
+            hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -198,7 +195,6 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
-        attention_mode: Optional[str] = None,
     ):
         super().__init__()
 
@@ -230,7 +226,6 @@ class CogVideoXBlock(nn.Module):
         )
         self.cached_hidden_states = []
         self.cached_encoder_hidden_states = []
-        self.attention_mode = attention_mode
 
     def forward(
         self,
@@ -244,6 +239,7 @@ class CogVideoXBlock(nn.Module):
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
+        attention_mode="sdpa",
     ) -> torch.Tensor:
         #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
@@ -282,7 +278,7 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=self.attention_mode,
+                attention_mode=attention_mode,
             )
             if fastercache_counter == fastercache_start_step:
                 self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -295,7 +291,7 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=self.attention_mode,
+                attention_mode=attention_mode,
             )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -404,7 +400,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
-        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -493,7 +488,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
         self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
-        self.attention_mode = attention_mode
+        self.attention_mode = "sdpa"
+
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
 
@@ -620,7 +616,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    attention_mode = self.attention_mode
                 )
 
                 if (controlnet_states is not None) and (i < len(controlnet_states)):
@@ -690,7 +687,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    attention_mode = self.attention_mode
                 )
                 #has_nan = torch.isnan(hidden_states).any()
                 #if has_nan:
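
Note (illustrative, not part of the patch): after this change `attention_mode` is no longer a constructor/config argument; it is a plain attribute on `CogVideoXTransformer3DModel` (defaulting to "sdpa") that forward passes down to every block, and the attention processor now calls `sageattn` unconditionally when the mode is "sageattn". A minimal caller-side sketch follows; the `transformer` variable name and the try/except availability check are assumptions, since the patch drops the old SAGEATTN_IS_AVAILABLE guard from the processor.

# Minimal sketch, assuming `transformer` is an instance of CogVideoXTransformer3DModel
# as patched above. The caller-side availability check is an assumption: the patch
# removed the SAGEATTN_IS_AVAILABLE guard, so selecting "sageattn" without the
# sageattention package installed would fail inside the attention processor.
attention_mode = "sdpa"  # default set in __init__ by this patch
try:
    from sageattention import sageattn  # noqa: F401 -- imported only to confirm it is installed
    attention_mode = "sageattn"
except ImportError:
    pass  # fall back to torch.nn.functional.scaled_dot_product_attention

# attention_mode is now a runtime attribute rather than a config value,
# so it can be switched between runs without re-instantiating the model.
transformer.attention_mode = attention_mode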