Merge pull request #261 from Dango233/Dango233-patch-1

Fix fused sdpa
This commit is contained in:
Jukka Seppänen 2024-11-20 12:37:26 +02:00 committed by GitHub
commit b9f7b6e338
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -128,7 +128,7 @@ class CogVideoXAttnProcessor2_0:
         if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
             hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        elif attention_mode == "sdpa":
+        elif attention_mode == "sdpa" or attention_mode == "fused_sdpa":
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
@@ -751,4 +751,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)