diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 545e084..85687fe 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -528,6 +528,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         context_stride: Optional[int] = None,
         context_overlap: Optional[int] = None,
         freenoise: Optional[bool] = True,
+        tora: Optional[dict] = None,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -720,7 +721,13 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
-
+        if tora is not None and do_classifier_free_guidance:
+            video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
+
+        if tora is not None:
+            for module in self.transformer.fuser_list:
+                for param in module.parameters():
+                    param.data = param.data.to(device)
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
@@ -910,6 +917,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                     control_latents=current_control_latents,
+                    video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
+
                 )[0]
                 noise_pred = noise_pred.float()
diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
index 459e845..1ee3920 100644
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py
@@ -610,6 +610,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         context_stride: Optional[int] = None,
         context_overlap: Optional[int] = None,
         freenoise: Optional[bool] = True,
+        tora: Optional[dict] = None,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -889,6 +890,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
+        if tora is not None and do_classifier_free_guidance:
+            video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
+
+        if tora is not None:
+            for module in self.transformer.fuser_list:
+                for param in module.parameters():
+                    param.data = param.data.to(device)
 
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1061,6 +1069,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
+                current_step_percentage = i / num_inference_steps
+
                 # predict noise model_output
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1069,6 +1079,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                     inpaint_latents=inpaint_latents,
+                    video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
+
                 )[0]
                 noise_pred = noise_pred.float()
diff --git a/cogvideox_fun/transformer_3d.py b/cogvideox_fun/transformer_3d.py
index 8a607b4..2b57923 100644
--- a/cogvideox_fun/transformer_3d.py
+++ b/cogvideox_fun/transformer_3d.py
@@ -34,6 +34,7 @@ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
+from einops import rearrange
 try:
     from sageattention import sageattn
     SAGEATTN_IS_AVAVILABLE = True
@@ -42,6 +43,23 @@ except:
     logger.info("sageattn not found, using sdpa")
     SAGEATTN_IS_AVAVILABLE = False
 
+def fft(tensor):
+    tensor_fft = torch.fft.fft2(tensor)
+    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
+    B, C, H, W = tensor.size()
+    radius = min(H, W) // 5
+
+    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
+    center_x, center_y = W // 2, H // 2
+    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
+    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
+    high_freq_mask = ~low_freq_mask
+
+    low_freq_fft = tensor_fft_shifted * low_freq_mask
+    high_freq_fft = tensor_fft_shifted * high_freq_mask
+
+    return low_freq_fft, high_freq_fft
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -315,6 +333,11 @@ class CogVideoXBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        video_flow_feature: Optional[torch.Tensor] = None,
+        fuser=None,
+        fastercache_counter=0,
+        fastercache_start_step=15,
+        fastercache_device="cuda:0",
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -322,7 +345,39 @@ class CogVideoXBlock(nn.Module):
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
-
+        # Tora Motion-guidance Fuser
+        if video_flow_feature is not None:
+            H, W = video_flow_feature.shape[-2:]
+            T = norm_hidden_states.shape[1] // H // W
+            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
+            h = fuser(h, video_flow_feature.to(h), T=T)
+            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+            del h, fuser
+        #fastercache
+        B = norm_hidden_states.shape[0]
+        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
+            attn_hidden_states = (
+                self.cached_hidden_states[1][:B]
+                + (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
+                * 0.3
+            ).to(norm_hidden_states.device, non_blocking=True)
+            attn_encoder_hidden_states = (
+                self.cached_encoder_hidden_states[1][:B]
+                + (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
+                * 0.3
+            ).to(norm_hidden_states.device, non_blocking=True)
+        else:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+            )
+            if fastercache_counter == fastercache_start_step:
+                self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
+                self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
+            elif fastercache_counter > fastercache_start_step:
+                self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
+                self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
         # attention
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
@@ -497,6 +552,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+        self.fuser_list = None
+
+        self.use_fastercache = False
+        self.fastercache_counter = 0
+        self.fastercache_start_step = 15
+        self.fastercache_lf_step = 40
+        self.fastercache_hf_step = 30
+        self.fastercache_device = "cuda"
+
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
 
@@ -609,6 +673,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         inpaint_latents: Optional[torch.Tensor] = None,
         control_latents: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        video_flow_features: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -649,50 +714,101 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
-        for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
+        if self.use_fastercache:
+            self.fastercache_counter+=1
+        if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
+            # 4. Transformer blocks
+            for i, block in enumerate(self.transformer_blocks):
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states[:1],
+                    encoder_hidden_states=encoder_hidden_states[:1],
+                    temb=emb[:1],
+                    image_rotary_emb=image_rotary_emb,
+                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
+                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    fastercache_counter = self.fastercache_counter,
+                    fastercache_device = self.fastercache_device
                 )
+
+            if not self.config.use_rotary_positional_embeddings:
+                # CogVideoX-2B
+                hidden_states = self.norm_final(hidden_states)
             else:
+                # CogVideoX-5B
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                hidden_states = self.norm_final(hidden_states)
+                hidden_states = hidden_states[:, text_seq_length:]
+
+            # 5. Final block
+            hidden_states = self.norm_out(hidden_states, temb=emb[:1])
+            hidden_states = self.proj_out(hidden_states)
+
+            # 6. Unpatchify
+            p = self.config.patch_size
+            output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+            (bb, tt, cc, hh, ww) = output.shape
+            cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+            lf_c, hf_c = fft(cond.float())
+            #lf_step = 40
+            #hf_step = 30
+            if self.fastercache_counter <= self.fastercache_lf_step:
+                self.delta_lf = self.delta_lf * 1.1
+            if self.fastercache_counter >= self.fastercache_hf_step:
+                self.delta_hf = self.delta_hf * 1.1
+
+            new_hf_uc = self.delta_hf + hf_c
+            new_lf_uc = self.delta_lf + lf_c
+
+            combine_uc = new_lf_uc + new_hf_uc
+            combined_fft = torch.fft.ifftshift(combine_uc)
+            recovered_uncond = torch.fft.ifft2(combined_fft).real
+            recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+            output = torch.cat([output, recovered_uncond])
+        else:
+            # 4. Transformer blocks
+            for i, block in enumerate(self.transformer_blocks):
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    fastercache_counter = self.fastercache_counter,
+                    fastercache_device = self.fastercache_device
                 )
 
-        if not self.config.use_rotary_positional_embeddings:
-            # CogVideoX-2B
-            hidden_states = self.norm_final(hidden_states)
-        else:
-            # CogVideoX-5B
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-            hidden_states = self.norm_final(hidden_states)
-            hidden_states = hidden_states[:, text_seq_length:]
+            if not self.config.use_rotary_positional_embeddings:
+                # CogVideoX-2B
+                hidden_states = self.norm_final(hidden_states)
+            else:
+                # CogVideoX-5B
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                hidden_states = self.norm_final(hidden_states)
+                hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
-        hidden_states = self.norm_out(hidden_states, temb=emb)
-        hidden_states = self.proj_out(hidden_states)
+            # 5. Final block
+            hidden_states = self.norm_out(hidden_states, temb=emb)
+            hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
-        p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+            # 6. Unpatchify
+            p = self.config.patch_size
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+            if self.fastercache_counter >= self.fastercache_start_step + 1:
+                (bb, tt, cc, hh, ww) = output.shape
+                cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
+                uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
+
+                lf_c, hf_c = fft(cond)
+                lf_uc, hf_uc = fft(uncond)
+
+                self.delta_lf = lf_uc - lf_c
+                self.delta_hf = hf_uc - hf_c
+
         if not return_dict:
             return (output,)
diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index d4bc137..dc33eec 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -662,7 +662,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                     encoder_hidden_states=encoder_hidden_states[:1],
                     temb=emb[:1],
                     image_rotary_emb=image_rotary_emb,
-                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_device = self.fastercache_device
diff --git a/nodes.py b/nodes.py
index d7e72a4..5d2139f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1480,6 +1480,8 @@ class CogVideoXFunSampler:
                 "opt_empty_latent": ("LATENT",),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "context_options": ("COGCONTEXT", ),
+                "tora_trajectory": ("TORAFEATURES", ),
+                "fastercache": ("FASTERCACHEARGS",),
             },
         }
 
@@ -1489,7 +1491,7 @@ class CogVideoXFunSampler:
     CATEGORY = "CogVideoWrapper"
 
     def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
-                start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None):
+                start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None, fastercache=None, tora_trajectory=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1538,6 +1540,20 @@ class CogVideoXFunSampler:
         else:
             context_frames, context_stride, context_overlap = None, None, None
 
+        if tora_trajectory is not None:
+            pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
+
+        if fastercache is not None:
+            pipe.transformer.use_fastercache = True
+            pipe.transformer.fastercache_counter = 0
+            pipe.transformer.fastercache_start_step = fastercache["start_step"]
+            pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
+            pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
+            pipe.transformer.fastercache_device = fastercache["cache_device"]
+        else:
+            pipe.transformer.use_fastercache = False
+            pipe.transformer.fastercache_counter = 0
+
         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
@@ -1564,7 +1580,8 @@ class CogVideoXFunSampler:
                 context_frames=context_frames,
                 context_stride= context_stride,
                 context_overlap= context_overlap,
-                freenoise=context_options["freenoise"] if context_options is not None else None
+                freenoise=context_options["freenoise"] if context_options is not None else None,
+                tora=tora_trajectory if tora_trajectory is not None else None,
             )
             #if not pipeline["cpu_offloading"]:
             #    pipe.transformer.to(offload_device)
diff --git a/tora/traj_module.py b/tora/traj_module.py
index edec219..2cdecf0 100644
--- a/tora/traj_module.py
+++ b/tora/traj_module.py
@@ -287,11 +287,21 @@ class MGF(nn.Module):
         gamma_flow = self.flow_gamma_spatial(flow)
         beta_flow = self.flow_beta_spatial(flow)
         _, _, hh, wh = beta_flow.shape
-        gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
-        beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
-        gamma_flow = self.flow_gamma_temporal(gamma_flow)
-        beta_flow = self.flow_beta_temporal(beta_flow)
-        gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
-        beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
+
+        if gamma_flow.shape[0] == 1:  # Check if batch size is 1
+            gamma_flow = rearrange(gamma_flow, "b c h w -> b c (h w)")
+            beta_flow = rearrange(beta_flow, "b c h w -> b c (h w)")
+            gamma_flow = self.flow_gamma_temporal(gamma_flow)
+            beta_flow = self.flow_beta_temporal(beta_flow)
+            gamma_flow = rearrange(gamma_flow, "b c (h w) -> b c h w", h=hh, w=wh)
+            beta_flow = rearrange(beta_flow, "b c (h w) -> b c h w", h=hh, w=wh)
+        else:
+            gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
+            beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
+            gamma_flow = self.flow_gamma_temporal(gamma_flow)
+            beta_flow = self.flow_beta_temporal(beta_flow)
+            gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
+            beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
+
         h = h + self.flow_cond_norm(h) * gamma_flow + beta_flow
         return h
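For reference, a minimal sketch (not part of the patch) of the option dicts the new `tora_trajectory` and `fastercache` sampler inputs are expected to carry, based only on the keys the patched code reads (`video_flow_features`, `fuser_list`, `start_percent`, `end_percent`, `start_step`, `lf_step`, `hf_step`, `cache_device`). The tensor shape, module list, and default values below are illustrative assumptions; in practice these dicts are produced by the upstream Tora and FasterCache nodes.

# Hypothetical example of the dicts consumed by CogVideoXFunSampler.process()
import torch

tora_trajectory = {
    # one flow feature map per transformer block; shape here is purely illustrative
    "video_flow_features": torch.zeros(42, 1, 16, 60, 90),
    # one MGF fuser module per transformer block, assigned to pipe.transformer.fuser_list
    "fuser_list": [],
    # fraction of the denoising steps during which flow guidance is applied
    "start_percent": 0.0,
    "end_percent": 1.0,
}

fastercache = {
    "start_step": 15,        # step at which attention caching begins
    "lf_step": 40,           # low-frequency delta keeps being scaled up through this step
    "hf_step": 30,           # high-frequency delta is scaled up from this step onward
    "cache_device": "cuda",  # device that holds the cached hidden states
}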