Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
fix sageattention
commit eebdc412f9 (parent 15aa68c95d)
@@ -121,10 +121,7 @@ class CogVideoXAttnProcessor2_0:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         if attention_mode == "sageattn":
-            if SAGEATTN_IS_AVAILABLE:
-                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
-            else:
-                raise ImportError("sageattn not found")
+            hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
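
The hunk above is the sageattention fix itself: the per-call SAGEATTN_IS_AVAILABLE branch is removed and sageattn is called directly once attention_mode == "sageattn" is selected. A self-contained sketch of the resulting dispatch, with the availability check hoisted to import time; the placement of the guard is an assumption (the flag name comes from the removed lines, and `from sageattention import sageattn` is the package's public import), not code from this file:

    import torch
    import torch.nn.functional as F

    # Resolve SageAttention availability once at import time rather than on
    # every attention call (assumed placement; the flag name matches the
    # guard this commit removes from the hot path).
    try:
        from sageattention import sageattn
        SAGEATTN_IS_AVAILABLE = True
    except ImportError:
        SAGEATTN_IS_AVAILABLE = False

    def dispatch_attention(query, key, value, attention_mask=None, attention_mode="sdpa"):
        # query/key/value: (batch, heads, seq_len, head_dim)
        if attention_mode == "sageattn":
            if not SAGEATTN_IS_AVAILABLE:
                raise ImportError("attention_mode='sageattn' but sageattention is not installed")
            # Minimal documented signature; the diff also forwards mask and
            # dropout kwargs, which this sketch omits.
            return sageattn(query, key, value, is_causal=False)
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )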
@@ -198,7 +195,6 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
-        attention_mode: Optional[str] = None,
     ):
         super().__init__()
 
@@ -230,7 +226,6 @@ class CogVideoXBlock(nn.Module):
         )
         self.cached_hidden_states = []
         self.cached_encoder_hidden_states = []
-        self.attention_mode = attention_mode
 
     def forward(
         self,
@@ -244,6 +239,7 @@ class CogVideoXBlock(nn.Module):
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
+        attention_mode="sdpa",
     ) -> torch.Tensor:
         #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
@@ -282,7 +278,7 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=self.attention_mode,
+                attention_mode=attention_mode,
             )
             if fastercache_counter == fastercache_start_step:
                 self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -295,7 +291,7 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=self.attention_mode,
+                attention_mode=attention_mode,
             )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
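
Together with the constructor hunks, the two call-site hunks above move attention_mode out of CogVideoXBlock state (self.attention_mode) into a forward() argument, so the mode held by the top-level model is threaded through each call instead of being frozen per block at construction. A minimal, hypothetical sketch of that threading pattern (toy classes, not the wrapper's real ones):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyBlock(nn.Module):
        # The block no longer stores the mode; forward() receives it per call.
        def forward(self, q, k, v, attention_mode="sdpa"):
            if attention_mode == "sdpa":
                return F.scaled_dot_product_attention(q, k, v)
            raise ValueError(f"unhandled attention_mode: {attention_mode}")

    class ToyModel(nn.Module):
        # The top-level module keeps the single source of truth for the mode.
        def __init__(self):
            super().__init__()
            self.attention_mode = "sdpa"  # mirrors the new default in the diff
            self.blocks = nn.ModuleList(ToyBlock() for _ in range(2))

        def forward(self, q, k, v):
            for block in self.blocks:
                q = block(q, k, v, attention_mode=self.attention_mode)
            return q

    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
    out = ToyModel()(q, k, v)

Passing the mode per call means one model instance can switch attention kernels between runs without rebuilding its blocks.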
@@ -404,7 +400,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
-        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -493,7 +488,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
         self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
-        self.attention_mode = attention_mode
+        self.attention_mode = "sdpa"
+
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
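
With attention_mode dropped from the constructors, the hunk above hard-codes self.attention_mode = "sdpa" as the initial value, so the attribute is evidently meant to be reassigned after construction (for example by a loader node) rather than passed as a config argument. A tiny, hypothetical sketch of that usage:

    import torch.nn as nn

    class ToyTransformer(nn.Module):
        def __init__(self):
            super().__init__()
            # No attention_mode parameter anymore: start from the safe default.
            self.attention_mode = "sdpa"

    model = ToyTransformer()
    model.attention_mode = "sageattn"  # set externally at load time (assumed usage)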
@@ -620,7 +616,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    attention_mode = self.attention_mode
                 )
 
                 if (controlnet_states is not None) and (i < len(controlnet_states)):
@@ -690,7 +687,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    attention_mode = self.attention_mode
                 )
             #has_nan = torch.isnan(hidden_states).any()
             #if has_nan: