From 0ea77bc63fd79800c073d4d6424c76f350b8d425 Mon Sep 17 00:00:00 2001
From: zhilemann <144045746+zhilemann@users.noreply.github.com>
Date: Sun, 22 Dec 2024 03:02:23 +0300
Subject: [PATCH] add basic spatiotemporal guidance impl

---
 custom_cogvideox_transformer_3d.py |  90 +++++++++++++++++++++++-
 nodes.py                           |  34 ++++++++-
 pipeline_cogvideox.py              | 107 +++++++++++++++++++++++++----
 3 files changed, 213 insertions(+), 18 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index efccb33..d66e369 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -113,6 +113,8 @@ class CogVideoXAttnProcessor2_0:
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         self.attention_mode = attention_mode
         self.attn_func = attn_func
+        self.stg_mode: Optional[str] = None
+
     def __call__(
         self,
         attn: Attention,
@@ -121,6 +123,16 @@ class CogVideoXAttnProcessor2_0:
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if self.stg_mode == "STG-R":
+            return hidden_states, encoder_hidden_states
+
+        if self.stg_mode == "STG-A":
+            hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
+            encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)
+
+            hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond], dim=0)
+            encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)
+
         text_seq_length = encoder_hidden_states.size(1)
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -166,9 +178,9 @@ class CogVideoXAttnProcessor2_0:
         #feta
         if is_enhance_enabled():
             feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)
-
+        
         hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
-
+        
         if self.attention_mode != "comfy":
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -184,6 +196,78 @@ class CogVideoXAttnProcessor2_0:
 
         if is_enhance_enabled():
             hidden_states *= feta_scores
+        if self.stg_mode == "STG-A":
+            text_seq_length = encoder_hidden_states_perturb.size(1)
+
+            hidden_states_perturb = torch.cat([encoder_hidden_states_perturb, hidden_states_perturb], dim=1)
+
+            batch_size, sequence_length, _ = (
+                hidden_states_perturb.shape if encoder_hidden_states_perturb is None else encoder_hidden_states_perturb.shape
+            )
+
+            if attention_mask is not None:
+                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+            if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
+                hidden_states_perturb = hidden_states_perturb.to(attn.to_q.weight.dtype)
+
+            if "fused" not in self.attention_mode:
+                query_perturb = attn.to_q(hidden_states_perturb)
+                key_perturb = attn.to_k(hidden_states_perturb)
+                value_perturb = attn.to_v(hidden_states_perturb)
+            else:
+                qkv = attn.to_qkv(hidden_states_perturb)
+                split_size = qkv.shape[-1] // 3
+                query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)
+
+            inner_dim = key_perturb.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query_perturb = query_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key_perturb = key_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value_perturb = value_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            if attn.norm_q is not None:
+                query_perturb = attn.norm_q(query_perturb)
+            if attn.norm_k is not None:
+                key_perturb = attn.norm_k(key_perturb)
+
+            # Apply RoPE if needed
+            if image_rotary_emb is not None:
+                query_perturb[:, :, text_seq_length:] = apply_rotary_emb(query_perturb[:, :, text_seq_length:], image_rotary_emb)
+                if not attn.is_cross_attention:
+                    key_perturb[:, :, text_seq_length:] = apply_rotary_emb(key_perturb[:, :, text_seq_length:], image_rotary_emb)
+
+            full_seq_length = query_perturb.size(2)
+            identity_block_size = full_seq_length - text_seq_length
+
+            full_mask = torch.zeros((full_seq_length, full_seq_length), device=query_perturb.device, dtype=query_perturb.dtype)
+
+            full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+            full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+            full_mask = full_mask.unsqueeze(0).unsqueeze(0)
+
+            hidden_states_perturb = self.attn_func(
+                query_perturb, key_perturb, value_perturb, attn_mask=full_mask, is_causal=False
+            )
+
+            if self.attention_mode != "comfy":
+                hidden_states_perturb = hidden_states_perturb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+            # linear proj
+            hidden_states_perturb = attn.to_out[0](hidden_states_perturb)
+            # dropout
+            hidden_states_perturb = attn.to_out[1](hidden_states_perturb)
+
+            encoder_hidden_states_perturb, hidden_states_perturb = hidden_states_perturb.split(
+                [text_seq_length, hidden_states_perturb.size(1) - text_seq_length], dim=1
+            )
+
+            hidden_states = torch.cat([hidden_states, hidden_states_perturb], dim=0)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_perturb], dim=0)
+
         return hidden_states, encoder_hidden_states
 
 #region Blocks
@@ -731,4 +815,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
-        
\ No newline at end of file
+        
diff --git a/nodes.py b/nodes.py
index 470a54b..49ffb8a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,4 +1,4 @@
-import os
+import os, re
 import torch
 import json
 from einops import rearrange
@@ -96,6 +96,31 @@ class CogVideoContextOptions:
 
         return (context_options,)
 
+class CogVideoSTG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "mode": (["STG-A", "STG-R"],),
+                "blocks": ("STRING", {"default": "30", "tooltip": "Block indices to apply STG to"}),
+                "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤2.0"}),
+                "rescale": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤1.0"}),
+            },
+        }
+    RETURN_TYPES = ("STGARGS",)
+    RETURN_NAMES = ("stg_args",)
+    FUNCTION = "setargs"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "https://github.com/junhahyung/STGuidance"
+
+    def setargs(self, mode, blocks, scale, rescale):
+        return ({
+            "stg_mode": mode,
+            "stg_layers_idx": [int(x) for x in re.findall(r"\d+", blocks)],
+            "stg_scale": scale,
+            "stg_rescaling": rescale
+        }, )
+
 class CogVideoTransformerEdit:
     @classmethod
     def INPUT_TYPES(s):
@@ -612,6 +637,7 @@ class CogVideoSampler:
                 "tora_trajectory": ("TORAFEATURES", ),
                 "fastercache": ("FASTERCACHEARGS", ),
                 "feta_args": ("FETAARGS", ),
+                "stg_args": ("STGARGS", ),
             }
         }
 
@@ -621,7 +647,8 @@ class CogVideoSampler:
     CATEGORY = "CogVideoWrapper"
 
     def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
-                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
+                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None,
+                tora_trajectory=None, fastercache=None, feta_args=None, stg_args=None):
         mm.unload_all_models()
         mm.soft_empty_cache()
 
@@ -743,6 +770,7 @@ class CogVideoSampler:
                     image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
                     image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
                     feta_args=feta_args,
+                    **(stg_args if stg_args else {}),
                 )
             if not model["cpu_offloading"] and model["manual_offloading"]:
                 pipe.transformer.to(offload_device)
@@ -973,6 +1001,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
     "CogVideoTransformerEdit": CogVideoTransformerEdit,
     "CogVideoContextOptions": CogVideoContextOptions,
+    "CogVideoSTG": CogVideoSTG,
     "CogVideoControlNet": CogVideoControlNet,
     "ToraEncodeTrajectory": ToraEncodeTrajectory,
     "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
@@ -991,6 +1020,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
     "CogVideoTransformerEdit": "CogVideo TransformerEdit",
     "CogVideoContextOptions": "CogVideo Context Options",
+    "CogVideoSTG": "CogVideo Spatiotemporal Guidance",
     "ToraEncodeTrajectory": "Tora Encode Trajectory",
     "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
     "CogVideoXFasterCache": "CogVideoX FasterCache",
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index cb08e1e..c4ec8a1 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -316,6 +316,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1
 
+    @property
+    def do_spatiotemporal_guidance(self):
+        return self._stg_scale > 0
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -324,6 +328,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     def interrupt(self):
         return self._interrupt
 
+    def extract_attn_layers(self):
+        for key, mod in self.transformer.named_modules():
+            if "attn1" in key and "to" not in key and "add" not in key and "norm" not in key:
+                yield mod
+
     @torch.no_grad()
     def __call__(
         self,
@@ -353,6 +362,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         image_cond_start_percent: float = 0.0,
         image_cond_end_percent: float = 1.0,
         feta_args: Optional[dict] = None,
+        stg_mode: Optional[str] = None, # "STG-A"
+        stg_layers_idx: Optional[List[int]] = [], # [30]
+        stg_scale: Optional[float] = 0.0, # 4.0
+        stg_rescaling: Optional[float] = None, # 0.7
     ):
         """
 
@@ -409,9 +422,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             prompt_embeds,
             negative_prompt_embeds,
         )
+        self._stg_scale = stg_scale
         self._guidance_scale = guidance_scale
         self._interrupt = False
 
+        for idx, mod in enumerate(self.extract_attn_layers()):
+            mod.processor.stg_mode = stg_mode if idx in stg_layers_idx else None
+
         # 2. Default call parameters
         batch_size = prompt_embeds.shape[0]
 
@@ -421,8 +438,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale[0] > 1.0
 
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
+
         prompt_embeds = prompt_embeds.to(self.vae_dtype)
 
         # 4. Prepare timesteps
@@ -615,13 +635,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     disable_enhance()
                 # region context schedule sampling
                 if use_context_schedule:
-                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                        latent_model_input = torch.cat([latents] * 2)
+                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                        latent_model_input = torch.cat([latents] * 3)
+                    else:
+                        latent_model_input = latents
+
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                     counter = torch.zeros_like(latent_model_input)
                     noise_pred = torch.zeros_like(latent_model_input)
 
                     if image_cond_latents is not None:
-                        latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
+                        if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                            latent_image_input = torch.cat([image_cond_latents] * 2)
+                        elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                            latent_image_input = torch.cat([image_cond_latents] * 3)
+                        else:
+                            latent_image_input = image_cond_latents
+
                         latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
 
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -708,9 +740,19 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     noise_pred = noise_pred.float()
                     noise_pred /= counter
 
-                    if do_classifier_free_guidance:
+                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                         noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
+
+                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+                        noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) \
+                            + self._stg_scale * (noise_pred_text - noise_pred_perturb)
+
+                        if stg_rescaling:
+                            factor = noise_pred_text.std() / noise_pred.std()
+                            factor = stg_rescaling * factor + (1 - stg_rescaling)
+                            noise_pred = noise_pred * factor
 
                     # compute the previous noisy sample x_t -> x_t-1
                     if not isinstance(self.scheduler, CogVideoXDPMScheduler):
@@ -733,25 +775,54 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
 
                 # region sampling
                 else:
-                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                        latent_model_input = torch.cat([latents] * 2)
+                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                        latent_model_input = torch.cat([latents] * 3)
+                    else:
+                        latent_model_input = latents
+
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                     if image_cond_latents is not None:
                         if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
                             latent_image_input = torch.zeros_like(latent_model_input)
                         else:
-                            latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
+                            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                                latent_image_input = torch.cat([image_cond_latents] * 2)
+                            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                                latent_image_input = torch.cat([image_cond_latents] * 3)
+                            else:
+                                latent_image_input = image_cond_latents
+
                         if fun_mask is not None: #for fun img2vid and interpolation
-                            fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
+                            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                                fun_inpaint_mask = torch.cat([fun_mask] * 2)
+                            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                                fun_inpaint_mask = torch.cat([fun_mask] * 3)
+                            else:
+                                fun_inpaint_mask = fun_mask
+
                             masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
                             latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
                         else:
                             latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
                     else: # for Fun inpaint vid2vid
                         if fun_mask is not None:
-                            fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
-                            fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
-                            fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
+                            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                                fun_inpaint_mask = torch.cat([fun_mask] * 2)
+                            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                                fun_inpaint_mask = torch.cat([fun_mask] * 3)
+                            else:
+                                fun_inpaint_mask = fun_mask
+
+                            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
+                                fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2)
+                            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                                fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 3)
+                            else:
+                                fun_inpaint_masked_video_latents = fun_masked_video_latents
+
+                            fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
+
                             latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
 
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -791,11 +862,21 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                         self._guidance_scale[i] = 1 + guidance_scale[i] * (
                             (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                         )
-
-                    if do_classifier_free_guidance:
+                    
+                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                         noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
 
+                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
+                        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+                        noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) \
+                            + self._stg_scale * (noise_pred_text - noise_pred_perturb)
+
+                        if stg_rescaling:
+                            factor = noise_pred_text.std() / noise_pred.std()
+                            factor = stg_rescaling * factor + (1 - stg_rescaling)
+                            noise_pred = noise_pred * factor
+
                     # compute the previous noisy sample x_t -> x_t-1
                     if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                         latents = self.scheduler.step(noise_pred, t, latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False)[0]
@@ -822,4 +903,4 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         # Offload all models
         self.maybe_free_model_hooks()
 
-        return latents
\ No newline at end of file
+        return latents
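
Note on the guidance math: in both sampling branches the three stacked predictions are combined as classifier-free guidance plus an extra STG term, with an optional rescaling toward the conditional prediction's standard deviation. The standalone sketch below mirrors that arithmetic with made-up tensor shapes and an illustrative function name (combine_guidance is not code from this repository):

    import torch

    def combine_guidance(noise_pred, cfg_scale, stg_scale, stg_rescaling=None):
        # noise_pred stacks the predictions for the [uncond, cond, perturbed] batch
        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)

        # classifier-free guidance plus the spatiotemporal guidance term
        guided = noise_pred_uncond \
            + cfg_scale * (noise_pred_text - noise_pred_uncond) \
            + stg_scale * (noise_pred_text - noise_pred_perturb)

        if stg_rescaling:
            # pull the std of the guided prediction back toward the conditional one
            factor = noise_pred_text.std() / guided.std()
            factor = stg_rescaling * factor + (1 - stg_rescaling)
            guided = guided * factor
        return guided

    # illustrative shapes only: batch of 3, 13 latent frames, 16 channels, 60x90 grid
    example = torch.randn(3, 13, 16, 60, 90)
    print(combine_guidance(example, cfg_scale=6.0, stg_scale=1.0, stg_rescaling=0.7).shape)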
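Note on layer selection: the pipeline only relies on named_modules() filtering to find the attn1 modules and then flags the requested block indices on their attention processors. The toy reproduction below uses stand-in classes (DummyProcessor, Block, ToyTransformer are illustrative, not from this repository) to show which modules end up flagged:

    import re
    import torch.nn as nn

    class DummyProcessor:
        # stand-in for CogVideoXAttnProcessor2_0; only the stg_mode flag matters here
        def __init__(self):
            self.stg_mode = None

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn1 = nn.Linear(64, 64)           # stand-in for the block's self-attention module
            self.attn1.processor = DummyProcessor()  # the real attention module keeps its processor in .processor

    class ToyTransformer(nn.Module):
        def __init__(self, num_blocks=4):
            super().__init__()
            self.transformer_blocks = nn.ModuleList([Block() for _ in range(num_blocks)])

    def extract_attn_layers(transformer):
        # same filter as the pipeline helper: keep the bare attn1 modules,
        # skip their to_q/to_k/to_v/to_out/norm_* children
        for key, mod in transformer.named_modules():
            if "attn1" in key and "to" not in key and "add" not in key and "norm" not in key:
                yield mod

    transformer = ToyTransformer()
    stg_layers_idx = [int(x) for x in re.findall(r"\d+", "1, 3")]  # same parsing as the CogVideoSTG node
    for idx, attn in enumerate(extract_attn_layers(transformer)):
        attn.processor.stg_mode = "STG-A" if idx in stg_layers_idx else None

    print([block.attn1.processor.stg_mode for block in transformer.transformer_blocks])
    # [None, 'STG-A', None, 'STG-A']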