diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 4cb64a6..efccb33 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -35,6 +35,9 @@ from diffusers.loaders import PeftAdapterMixin
 from diffusers.models.embeddings import apply_rotary_emb
 
 from .embeddings import CogVideoXPatchEmbed
+from .enhance_a_video.enhance import get_feta_scores
+from .enhance_a_video.globals import is_enhance_enabled, set_num_frames
+
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
@@ -159,6 +162,10 @@ class CogVideoXAttnProcessor2_0:
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        #feta
+        if is_enhance_enabled():
+            feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)
 
         hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
 
@@ -173,6 +180,10 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states, hidden_states = hidden_states.split(
             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
         )
+
+        if is_enhance_enabled():
+            hidden_states *= feta_scores
+
         return hidden_states, encoder_hidden_states
 
 #region Blocks
@@ -543,6 +554,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
+
+        set_num_frames(num_frames)
 
         # 1. Time embedding
         timesteps = timestep
diff --git a/enhance_a_video/__init__.py b/enhance_a_video/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/enhance_a_video/enhance.py b/enhance_a_video/enhance.py
new file mode 100644
index 0000000..12e2fa7
--- /dev/null
+++ b/enhance_a_video/enhance.py
@@ -0,0 +1,82 @@
+import torch
+from einops import rearrange
+from diffusers.models.attention import Attention
+from .globals import get_enhance_weight, get_num_frames
+
+# def get_feta_scores(query, key):
+#     img_q, img_k = query, key
+
+#     num_frames = get_num_frames()
+
+#     B, S, N, C = img_q.shape
+
+#     # Calculate spatial dimension
+#     spatial_dim = S // num_frames
+
+#     # Add time dimension between spatial and head dims
+#     query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
+#     key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)
+
+#     # Expand time dimension
+#     query_image = query_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]
+#     key_image = key_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]
+
+#     # Reshape to match feta_score input format: [(B S) N T C]
+#     query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128])
+#     key_image = rearrange(key_image, "b s t n c -> (b s) n t c")
+
+#     return feta_score(query_image, key_image, C, num_frames)
+
+def get_feta_scores(
+    attn: Attention,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_dim: int,
+    text_seq_length: int,
+) -> torch.Tensor:
+    num_frames = get_num_frames()
+    spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)
+
+    query_image = rearrange(
+        query[:, :, text_seq_length:],
+        "B N (T S) C -> (B S) N T C",
+        N=attn.heads,
+        T=num_frames,
+        S=spatial_dim,
+        C=head_dim,
+    )
+    key_image = rearrange(
+        key[:, :, text_seq_length:],
+        "B N (T S) C -> (B S) N T C",
+        N=attn.heads,
+        T=num_frames,
+        S=spatial_dim,
+        C=head_dim,
+    )
+    return feta_score(query_image, key_image, head_dim, num_frames)
+
+def feta_score(query_image, key_image, head_dim, num_frames):
+    scale = head_dim**-0.5
+    query_image = query_image * scale
+    attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32
+    attn_temp = attn_temp.to(torch.float32)
+    attn_temp = attn_temp.softmax(dim=-1)
+
+    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
+    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
+
+    # Create a mask for diagonal elements
+    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
+    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
+
+    # Zero out diagonal elements
+    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
+
+    # Calculate mean for each token's attention matrix
+    # Number of off-diagonal elements per matrix is n*n - n
+    num_off_diag = num_frames * num_frames - num_frames
+    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
+
+    enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
+    enhance_scores = enhance_scores.clamp(min=1)
+    return enhance_scores
diff --git a/enhance_a_video/globals.py b/enhance_a_video/globals.py
new file mode 100644
index 0000000..50d0da2
--- /dev/null
+++ b/enhance_a_video/globals.py
@@ -0,0 +1,31 @@
+NUM_FRAMES = None
+FETA_WEIGHT = None
+ENABLE_FETA = False
+
+def set_num_frames(num_frames: int):
+    global NUM_FRAMES
+    NUM_FRAMES = num_frames
+
+
+def get_num_frames() -> int:
+    return NUM_FRAMES
+
+
+def enable_enhance():
+    global ENABLE_FETA
+    ENABLE_FETA = True
+
+def disable_enhance():
+    global ENABLE_FETA
+    ENABLE_FETA = False
+
+def is_enhance_enabled() -> bool:
+    return ENABLE_FETA
+
+def set_enhance_weight(feta_weight: float):
+    global FETA_WEIGHT
+    FETA_WEIGHT = feta_weight
+
+
+def get_enhance_weight() -> float:
+    return FETA_WEIGHT
diff --git a/nodes.py b/nodes.py
index 08997b0..470a54b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -49,6 +49,25 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
 if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
     folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
 
+class CogVideoEnhanceAVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "weight": ("FLOAT", {"default": 1.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
+                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
+            },
+        }
+    RETURN_TYPES = ("FETAARGS",)
+    RETURN_NAMES = ("feta_args",)
+    FUNCTION = "setargs"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
+
+    def setargs(self, **kwargs):
+        return (kwargs, )
+
 class CogVideoContextOptions:
     @classmethod
     def INPUT_TYPES(s):
@@ -592,6 +611,7 @@ class CogVideoSampler:
                 "controlnet": ("COGVIDECONTROLNET",),
                 "tora_trajectory": ("TORAFEATURES", ),
                 "fastercache": ("FASTERCACHEARGS", ),
+                "feta_args": ("FETAARGS", ),
             }
         }
 
@@ -601,7 +621,7 @@
     CATEGORY = "CogVideoWrapper"
 
     def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
-                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
+                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
         mm.unload_all_models()
         mm.soft_empty_cache()
 
@@ -722,6 +742,7 @@ class CogVideoSampler:
                 tora=tora_trajectory if tora_trajectory is not None else None,
                 image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
                 image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
+                feta_args=feta_args,
             )
         if not model["cpu_offloading"] and model["manual_offloading"]:
             pipe.transformer.to(offload_device)
@@ -960,6 +981,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoLatentPreview": CogVideoLatentPreview,
     "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
     "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
+    "CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoSampler": "CogVideo Sampler",
@@ -976,4 +998,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoLatentPreview": "CogVideo LatentPreview",
     "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
     "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
+    "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
 }
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 17702f1..cb08e1e 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -29,6 +29,7 @@ from diffusers.loaders import CogVideoXLoraLoaderMixin
 
 from .embeddings import get_3d_rotary_pos_embed
 from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight
 
 from comfy.utils import ProgressBar
 
@@ -351,6 +352,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         tora: Optional[dict] = None,
         image_cond_start_percent: float = 0.0,
         image_cond_end_percent: float = 1.0,
+        feta_args: Optional[dict] = None,
     ):
         """
 
@@ -573,7 +575,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         else:
             controlnet_states = None
             control_weights= None
-
+        # 9. Tora
         if tora is not None:
             trajectory_length = tora["video_flow_features"].shape[1]
             logger.info(f"Tora trajectory length: {trajectory_length}")
@@ -585,16 +587,32 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
 
         logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
 
+        if feta_args is not None:
+            set_enhance_weight(feta_args["weight"])
+            feta_start_percent = feta_args["start_percent"]
+            feta_end_percent = feta_args["end_percent"]
+            enable_enhance()
+        else:
+            disable_enhance()
+
+        # 11. Denoising loop
         from .latent_preview import prepare_callback
         callback = prepare_callback(self.transformer, num_inference_steps)
-        # 9. Denoising loop
         comfy_pbar = ProgressBar(len(timesteps))
         with self.progress_bar(total=len(timesteps)) as progress_bar:
             old_pred_original_sample = None # for DPM-solver++
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+
+                current_step_percentage = i / num_inference_steps
+
+                if feta_args is not None:
+                    if feta_start_percent <= current_step_percentage <= feta_end_percent:
+                        enable_enhance()
+                    else:
+                        disable_enhance()
+
                 # region context schedule sampling
                 if use_context_schedule:
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -609,8 +627,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                     timestep = t.expand(latent_model_input.shape[0])
 
-                    current_step_percentage = i / num_inference_steps
-
                     # use same rotary embeddings for all context windows
                     image_rotary_emb = (
                         self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -720,8 +736,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                current_step_percentage = i / num_inference_steps
-
                 if image_cond_latents is not None:
                     if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
                         latent_image_input = torch.zeros_like(latent_model_input)
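
Review note: the core of this patch is the scoring math added in enhance_a_video/enhance.py. Below is a minimal standalone sketch of that reduction so it can be sanity-checked outside the attention processor. The tensor shapes (batch, heads, frames, spatial tokens, head_dim) and the enhance weight are made-up toy numbers, not values from the patch; only the reduction itself mirrors feta_score().

# Toy re-implementation of feta_score() from enhance_a_video/enhance.py, for review only.
# All shapes and the enhance weight are invented; only the reduction mirrors the patch.
import torch
from einops import rearrange

def feta_score_demo(query_image, key_image, head_dim, num_frames, enhance_weight):
    # query_image / key_image: [(B*S), heads, T, head_dim] -- image tokens only
    scale = head_dim ** -0.5
    attn_temp = (query_image * scale) @ key_image.transpose(-2, -1)  # frame-to-frame attention
    attn_temp = attn_temp.to(torch.float32).softmax(dim=-1)
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # mean of the off-diagonal (cross-frame) attention mass per T x T matrix
    diag_mask = torch.eye(num_frames, dtype=torch.bool, device=attn_temp.device)
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / (num_frames * num_frames - num_frames)

    # scale by (T + weight) and never attenuate the hidden states (clamp at 1)
    return (mean_scores.mean() * (num_frames + enhance_weight)).clamp(min=1)

B, heads, T, S, head_dim = 1, 4, 5, 16, 64    # toy sizes
q = torch.randn(B, heads, T * S, head_dim)    # text tokens assumed already sliced off
k = torch.randn(B, heads, T * S, head_dim)
q_img = rearrange(q, "B N (T S) C -> (B S) N T C", T=T, S=S)
k_img = rearrange(k, "B N (T S) C -> (B S) N T C", T=T, S=S)
print(feta_score_demo(q_img, k_img, head_dim, T, enhance_weight=1.0))  # scalar tensor >= 1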
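
For context, the control flow across the four touched files is: the CogVideoEnhanceAVideo node only packs weight/start_percent/end_percent into a FETAARGS dict, the pipeline pushes that into the module-level globals and flips the enable flag per denoising step, the transformer records num_frames on each forward pass, and the attention processor multiplies its image-token output by the score whenever the flag is set. The sketch below is an illustrative driver of that flow for review purposes, not pipeline code; it assumes it is run from the wrapper root (so enhance_a_video imports as a top-level package), and the step/frame counts are invented.

# Illustrative driver of the node -> pipeline -> globals -> attention flow in this patch.
# Assumption: run from the wrapper root so enhance_a_video is importable; numbers are toy values.
from enhance_a_video.globals import (
    set_num_frames, set_enhance_weight, enable_enhance, disable_enhance, is_enhance_enabled,
)

feta_args = {"weight": 1.0, "start_percent": 0.2, "end_percent": 0.8}  # what CogVideoEnhanceAVideo outputs

set_enhance_weight(feta_args["weight"])   # the pipeline does this once before the loop
set_num_frames(13)                        # the transformer does this from hidden_states.shape

num_inference_steps = 25                  # toy value
for i in range(num_inference_steps):
    p = i / num_inference_steps
    if feta_args["start_percent"] <= p <= feta_args["end_percent"]:
        enable_enhance()
    else:
        disable_enhance()
    # inside the transformer forward, CogVideoXAttnProcessor2_0 checks this flag and,
    # when True, scales its image-token hidden states by get_feta_scores(...)
    print(f"step {i:02d}: enhance active = {is_enhance_enabled()}")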