fix sageattention

kijai 2024-11-17 21:43:53 +02:00
parent 15aa68c95d
commit eebdc412f9


@@ -121,10 +121,7 @@ class CogVideoXAttnProcessor2_0:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         if attention_mode == "sageattn":
-            if SAGEATTN_IS_AVAILABLE:
-                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
-            else:
-                raise ImportError("sageattn not found")
+            hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
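
Note: after this hunk the sageattn branch no longer raises ImportError at call time, which presumably relies on SAGEATTN_IS_AVAILABLE being resolved once at import time. A minimal sketch of such a module-level guard (an assumption — the actual guard lives outside this hunk and is not part of the diff; the fallback stub is hypothetical):

    # Hedged sketch, not the repository's actual code: a typical import-time guard.
    try:
        from sageattention import sageattn
        SAGEATTN_IS_AVAILABLE = True
    except ImportError:
        SAGEATTN_IS_AVAILABLE = False

        def sageattn(*args, **kwargs):
            # Hypothetical stub so selecting "sageattn" without the package still
            # fails with a clear message instead of a NameError.
            raise ImportError("sageattn not found, install the 'sageattention' package")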
@@ -198,7 +195,6 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
-        attention_mode: Optional[str] = None,
     ):
         super().__init__()
@@ -230,7 +226,6 @@ class CogVideoXBlock(nn.Module):
         )
         self.cached_hidden_states = []
         self.cached_encoder_hidden_states = []
-        self.attention_mode = attention_mode
 
     def forward(
         self,
@@ -244,6 +239,7 @@ class CogVideoXBlock(nn.Module):
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
+        attention_mode="sdpa",
     ) -> torch.Tensor:
         #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
@@ -282,7 +278,7 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=self.attention_mode,
+                attention_mode=attention_mode,
             )
             if fastercache_counter == fastercache_start_step:
                 self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -295,7 +291,7 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=self.attention_mode,
+                attention_mode=attention_mode,
             )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -404,7 +400,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
-        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -493,7 +488,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
         self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
-        self.attention_mode = attention_mode
+        self.attention_mode = "sdpa"
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
@@ -620,7 +616,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    attention_mode = self.attention_mode
                 )
 
                 if (controlnet_states is not None) and (i < len(controlnet_states)):
@@ -690,7 +687,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    attention_mode = self.attention_mode
                 )
                 #has_nan = torch.isnan(hidden_states).any()
                 #if has_nan:
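
Taken together, the hunks move attention_mode from an __init__/config parameter to a per-call forward argument: the transformer keeps a plain self.attention_mode attribute (default "sdpa") and passes it to every block on each forward, and the block forwards it to the attention processor. A small self-contained toy illustrating that threading pattern (class names and shapes below are illustrative stand-ins, not the real model):

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        # Stand-in for CogVideoXBlock: attention_mode arrives per forward call,
        # not via __init__, so the backend can change without rebuilding modules.
        def forward(self, x, attention_mode="sdpa"):
            return x, attention_mode

    class ToyTransformer(nn.Module):
        # Stand-in for CogVideoXTransformer3DModel after this commit.
        def __init__(self, num_blocks=2):
            super().__init__()
            self.blocks = nn.ModuleList([ToyBlock() for _ in range(num_blocks)])
            self.attention_mode = "sdpa"  # plain attribute, no longer an init/config arg

        def forward(self, x):
            for block in self.blocks:
                x, used = block(x, attention_mode=self.attention_mode)
            return x, used

    model = ToyTransformer()
    model.attention_mode = "sageattn"  # callers switch backends by setting the attribute
    _, used = model(torch.zeros(1))
    print(used)  # -> sageattn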