diff --git a/cogvideox_fun/transformer_3d.py b/cogvideox_fun/transformer_3d.py
index 83614e2..5b6fef9 100644
--- a/cogvideox_fun/transformer_3d.py
+++ b/cogvideox_fun/transformer_3d.py
@@ -37,11 +37,9 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 from einops import rearrange
 try:
     from sageattention import sageattn
-    SAGEATTN_IS_AVAVILABLE = True
-    logger.info("Using sageattn")
+    SAGEATTN_IS_AVAILABLE = True
 except:
-    logger.info("sageattn not found, using sdpa")
-    SAGEATTN_IS_AVAVILABLE = False
+    SAGEATTN_IS_AVAILABLE = False
 
 def fft(tensor):
     tensor_fft = torch.fft.fft2(tensor)
@@ -77,6 +75,7 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mode: Optional[str] = None,
     ) -> torch.Tensor:
 
         text_seq_length = encoder_hidden_states.size(1)
@@ -113,83 +112,12 @@ class CogVideoXAttnProcessor2_0:
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        if SAGEATTN_IS_AVAVILABLE:
-            hidden_states = sageattn(query, key, value, is_causal=False)
-        else:
-            hidden_states = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-        )
-        return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-    query and key vectors, but does not include spatial normalization.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            from diffusers.models.embeddings import apply_rotary_emb
-
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-            if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        if SAGEATTN_IS_AVAVILABLE:
-            hidden_states = sageattn(query, key, value, is_causal=False)
+
+        if attention_mode == "sageattn":
+            if SAGEATTN_IS_AVAILABLE:
+                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+            else:
+                raise ImportError("sageattn not found")
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -298,6 +226,7 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
@@ -326,7 +255,10 @@ class CogVideoXBlock(nn.Module):
             inner_dim=ff_inner_dim,
             bias=ff_bias,
         )
-
+        self.cached_hidden_states = []
+        self.cached_encoder_hidden_states = []
+        self.attention_mode = attention_mode
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -335,6 +267,7 @@
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         video_flow_feature: Optional[torch.Tensor] = None,
         fuser=None,
+        block_use_fastercache=False,
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
@@ -352,33 +285,43 @@
             h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
             h = fuser(h, video_flow_feature.to(h), T=T)
             norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
-            del h, fuser
-        #fastercache
-        B = norm_hidden_states.shape[0]
-        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
-            attn_hidden_states = (
-                self.cached_hidden_states[1][:B] +
-                (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
-                * 0.3
-            ).to(norm_hidden_states.device, non_blocking=True)
-            attn_encoder_hidden_states = (
-                self.cached_encoder_hidden_states[1][:B] +
-                (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
-                * 0.3
-            ).to(norm_hidden_states.device, non_blocking=True)
+            del h, fuser
+
+        #region fastercache
+        if block_use_fastercache:
+            B = norm_hidden_states.shape[0]
+            if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
+                attn_hidden_states = (
+                    self.cached_hidden_states[1][:B] +
+                    (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
+                    * 0.3
+                ).to(norm_hidden_states.device, non_blocking=True)
+                attn_encoder_hidden_states = (
+                    self.cached_encoder_hidden_states[1][:B] +
+                    (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
+                    * 0.3
+                ).to(norm_hidden_states.device, non_blocking=True)
+            else:
+                attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                    hidden_states=norm_hidden_states,
+                    encoder_hidden_states=norm_encoder_hidden_states,
+                    image_rotary_emb=image_rotary_emb,
+                    attention_mode=self.attention_mode,
+                )
+                if fastercache_counter == fastercache_start_step:
+                    self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
+                    self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
+                elif fastercache_counter > fastercache_start_step:
+                    self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
+                    self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
         else:
             attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
+                attention_mode=self.attention_mode,
             )
-            if fastercache_counter == fastercache_start_step:
-                self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
-                self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
-            elif fastercache_counter > fastercache_start_step:
-                self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
-                self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
-
+
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
@@ -481,6 +424,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         add_noise_in_inpaint_model: bool = False,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -554,6 +498,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         self.fastercache_lf_step = 40
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
+        self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
+        self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
@@ -720,6 +666,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                     image_rotary_emb=image_rotary_emb,
                     video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
+                    fastercache_start_step = self.fastercache_start_step,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_device = self.fastercache_device
                 )
@@ -770,7 +718,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                     image_rotary_emb=image_rotary_emb,
                     video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
+                    fastercache_start_step = self.fastercache_start_step,
                     fastercache_device = self.fastercache_device
                 )
diff --git a/nodes.py b/nodes.py
index 817683f..f2874cd 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1180,6 +1180,8 @@ class CogVideoXFunSampler:
             pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
             pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
             pipe.transformer.fastercache_device = fastercache["cache_device"]
+            pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]
+            log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}")
         else:
             pipe.transformer.use_fastercache = False
             pipe.transformer.fastercache_counter = 0
@@ -1187,7 +1189,7 @@ class CogVideoXFunSampler:
         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
-        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
+        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
         with autocast_context:
             video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
             if vid2vid_images is not None:
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 13c960e..87d19e9 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -472,8 +472,15 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         # 5. Prepare latents.
         latent_channels = self.vae.config.latent_channels
         latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
         # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
-        patch_size_t = self.transformer.config.patch_size_t
+        patch_size_t = getattr(self.transformer.config, "patch_size_t", None)
+        if patch_size_t is None:
+            self.transformer.config.patch_size_t = None
+        ofs_embed_dim = getattr(self.transformer.config, "ofs_embed_dim", None)
+        if ofs_embed_dim is None:
+            self.transformer.config.ofs_embed_dim = None
+
         self.additional_frames = 0
         if patch_size_t is not None and latent_frames % patch_size_t != 0:
             self.additional_frames = patch_size_t - latent_frames % patch_size_t
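
For reference, the per-block caching policy this patch gates behind block_use_fastercache works as follows: once fastercache_counter reaches fastercache_start_step + 3, any step whose counter is not a multiple of 3 skips attn1 entirely and reuses the cache, extrapolating as cached[1] + (cached[1] - cached[0]) * 0.3; slot 0 keeps the output from the start step, while slot 1 is refreshed (via copy_) on every step that still computes attention. Below is a minimal, self-contained sketch of that policy under illustrative names (fastercache_step and compute_attn are stand-ins, and the real code's batch-size check and cache-device offloading are omitted):

import torch

def fastercache_step(counter, start_step, cache, compute_attn):
    # Reuse path: after a 3-step warm-up past start_step, every counter that
    # is not a multiple of 3 extrapolates from the two cached outputs.
    if counter >= start_step + 3 and counter % 3 != 0 and cache:
        return cache[1] + (cache[1] - cache[0]) * 0.3
    # Compute path: run attention and maintain the two-slot cache.
    out = compute_attn()
    if counter == start_step:
        cache[:] = [out.clone(), out.clone()]  # seed both slots identically
    elif counter > start_step:
        cache[1] = out.clone()  # refresh only the newer slot (copy_ in the patch)
    return out

# Toy driver: a fake attention output that drifts linearly with the step index.
cache = []
for step in range(30):
    out = fastercache_step(step, 15, cache, lambda s=step: torch.full((2, 4), float(s)))

Note that because slot 0 stays frozen at the start-step output while slot 1 tracks the most recent computed step, the reused value adds 30% of the total drift accumulated since caching began, not a single-step delta, at least as the patch reads here.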