diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 7f0cf3b..9499094 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -52,7 +52,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -126,7 +126,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
diff --git a/nodes.py b/nodes.py
index d279033..5f74b48 100644
--- a/nodes.py
+++ b/nodes.py
@@ -408,7 +408,11 @@ class DownloadAndLoadCogVideoModel:
         if compile == "torch":
             torch._dynamo.config.suppress_errors = True
             pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+            #pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+            for i, block in enumerate(pipe.transformer.transformer_blocks):
+                if "CogVideoXBlock" in str(block):
+                    print(block)
+                    pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
@@ -559,10 +563,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             convert_fp8_linear(transformer, vae_dtype)

         # compilation
-        if compile == "torch":
-            torch._dynamo.config.suppress_errors = True
-            pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+        for i, block in enumerate(transformer.transformer_blocks):
+            transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")

         with open(scheduler_path) as f:
             scheduler_config = json.load(f)
diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py
index 6a482fa..b0e1aa5 100644
--- a/videosys/cogvideox_transformer_3d.py
+++ b/videosys/cogvideox_transformer_3d.py
@@ -9,7 +9,7 @@
 # --------------------------------------------------------

 from typing import Any, Dict, Optional, Tuple, Union
-
+from einops import rearrange
 import torch
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -42,7 +42,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -134,7 +134,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -286,7 +286,7 @@ class CogVideoXBlock(nn.Module):
         self.attn_count = 0
         self.last_attn = None
         self.block_idx = block_idx
-
+    #@torch.compiler.disable()
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -294,6 +294,8 @@
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         timestep=None,
+        video_flow_feature: Optional[torch.Tensor] = None,
+        fuser=None,
     ) -> torch.Tensor:

         text_seq_length = encoder_hidden_states.size(1)
@@ -301,7 +303,14 @@
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
-
+        # Tora Motion-guidance Fuser
+        if video_flow_feature is not None:
+            H, W = video_flow_feature.shape[-2:]
+            T = norm_hidden_states.shape[1] // H // W
+            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
+            h = fuser(h, video_flow_feature.to(h), T=T)
+            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+            del h, fuser
         # attention
         if enable_pab():
             broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
@@ -494,6 +503,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

         self.gradient_checkpointing = False

+        self.fuser_list = None
+
         # parallel
         #self.parallel_manager = None

@@ -524,6 +535,7 @@
         return_dict: bool = True,
         controlnet_states: torch.Tensor = None,
         controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
+        video_flow_features: Optional[torch.Tensor] = None,
     ):
         # if self.parallel_manager.cp_size > 1:
         #     (
@@ -574,31 +586,15 @@

         # 4. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=emb,
-                    image_rotary_emb=image_rotary_emb,
-                    timestep=timesteps if enable_pab() else None,
-                )
+            hidden_states, encoder_hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=emb,
+                image_rotary_emb=image_rotary_emb,
+                timestep=timesteps if enable_pab() else None,
+                video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                fuser=self.fuser_list[i] if self.fuser_list is not None else None,
+            )
             if (controlnet_states is not None) and (i < len(controlnet_states)):
                 controlnet_states_block = controlnet_states[i]
                 controlnet_block_weight = 1.0
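
Note on the compilation change in nodes.py: instead of wrapping the whole transformer in torch.compile, the patch compiles each transformer block individually, which keeps Dynamo graph breaks and recompiles confined to a single block rather than invalidating the full model graph. A minimal sketch of the same pattern on a generic module list, where TinyBlock is a hypothetical stand-in for CogVideoXBlock and not part of the patch:

    # Sketch of per-block torch.compile, as applied to transformer_blocks above.
    # TinyBlock is a hypothetical stand-in for CogVideoXBlock.
    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            return x + torch.relu(self.proj(x))

    blocks = nn.ModuleList(TinyBlock(64) for _ in range(4))

    # Compile each block separately rather than the whole model, so graph
    # breaks and recompiles stay local to one block.
    for i, block in enumerate(blocks):
        blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")

    x = torch.randn(2, 16, 64)
    for block in blocks:
        x = block(x)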
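
Note on the Tora motion-guidance hook in CogVideoXBlock.forward: the fuser operates on 2D feature maps, so the patch folds the flattened (T H W) token axis back into per-frame C x H x W maps, applies the fuser against the flow features, and flattens back to token form. A self-contained sketch of that round-trip, assuming illustrative shapes; DummyFuser is a hypothetical stand-in for the modules in self.fuser_list:

    # Sketch of the (B, T*H*W, C) <-> (B*T, C, H, W) round-trip used by the
    # Tora hook above. DummyFuser and the shapes are illustrative assumptions.
    import torch
    import torch.nn as nn
    from einops import rearrange

    class DummyFuser(nn.Module):
        def forward(self, h, flow, T):
            # Real fusers combine hidden states with flow features; here we just add.
            return h + flow

    B, T, H, W, C = 1, 4, 30, 45, 64
    norm_hidden_states = torch.randn(B, T * H * W, C)   # flattened video tokens
    video_flow_feature = torch.randn(B * T, C, H, W)    # per-frame flow maps
    fuser = DummyFuser()

    # Fold tokens back into spatial maps, fuse, then flatten again.
    h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
    h = fuser(h, video_flow_feature.to(h), T=T)
    out = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
    assert out.shape == norm_hidden_states.shape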