From b31a0256739be10cbd3af890d6d552aba22b9241 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Wed, 20 Nov 2024 17:40:28 +0800 Subject: [PATCH] Fix fused sdpa --- custom_cogvideox_transformer_3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 05bed34..27f2d5f 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -128,7 +128,7 @@ class CogVideoXAttnProcessor2_0: if attention_mode == "sageattn" or attention_mode == "fused_sageattn": hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - elif attention_mode == "sdpa": + elif attention_mode == "sdpa" or attention_mode == "fused_sdpa": hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -751,4 +751,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) - \ No newline at end of file +