Add Enhance-A-Video

https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video
kijai 2024-12-22 01:26:18 +02:00
parent fcc0f3e65a
commit f16d38a5d2
6 changed files with 170 additions and 7 deletions


@@ -35,6 +35,9 @@ from diffusers.loaders import PeftAdapterMixin
from diffusers.models.embeddings import apply_rotary_emb
from .embeddings import CogVideoXPatchEmbed
from .enhance_a_video.enhance import get_feta_scores
from .enhance_a_video.globals import is_enhance_enabled, set_num_frames

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -159,6 +162,10 @@ class CogVideoXAttnProcessor2_0:
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # FETA (Enhance-A-Video): score cross-frame attention before the fused attention call
        if is_enhance_enabled():
            feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)

        hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
@@ -173,6 +180,10 @@ class CogVideoXAttnProcessor2_0:
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )

        # Scale the image-token hidden states by the Enhance-A-Video score
        if is_enhance_enabled():
            hidden_states *= feta_scores

        return hidden_states, encoder_hidden_states
#region Blocks
@@ -543,6 +554,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape
        set_num_frames(num_frames)

        # 1. Time embedding
        timesteps = timestep
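set_num_frames() records the latent frame count so the attention processor can split its flattened token sequence back into frames. A minimal sanity-check sketch of the layout this relies on; the sizes below are illustrative, not taken from the commit:

import torch

# Illustrative sizes only: text tokens followed by num_frames blocks of spatial tokens
batch, heads, head_dim = 1, 48, 64
text_seq_length, num_frames, tokens_per_frame = 226, 13, 1350

seq_len = text_seq_length + num_frames * tokens_per_frame
query = torch.randn(batch, heads, seq_len, head_dim)

# get_feta_scores (below) recovers the per-frame token count from the sequence length
spatial_dim = (query.shape[2] - text_seq_length) // num_frames
assert spatial_dim == tokens_per_frame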



@@ -0,0 +1,82 @@
import torch
from einops import rearrange

from diffusers.models.attention import Attention

from .globals import get_enhance_weight, get_num_frames

# def get_feta_scores(query, key):
#     img_q, img_k = query, key
#     num_frames = get_num_frames()
#     B, S, N, C = img_q.shape
#     # Calculate spatial dimension
#     spatial_dim = S // num_frames
#     # Add time dimension between spatial and head dims
#     query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
#     key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)
#     # Expand time dimension
#     query_image = query_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]
#     key_image = key_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]
#     # Reshape to match feta_score input format: [(B S) N T C]
#     query_image = rearrange(query_image, "b s t n c -> (b s) n t c")  # torch.Size([3200, 24, 5, 128])
#     key_image = rearrange(key_image, "b s t n c -> (b s) n t c")
#     return feta_score(query_image, key_image, C, num_frames)
def get_feta_scores(
    attn: Attention,
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    text_seq_length: int,
) -> torch.Tensor:
    num_frames = get_num_frames()
    spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)

    # Drop the text tokens and regroup the image tokens as (frames, spatial) per head
    query_image = rearrange(
        query[:, :, text_seq_length:],
        "B N (T S) C -> (B S) N T C",
        N=attn.heads,
        T=num_frames,
        S=spatial_dim,
        C=head_dim,
    )
    key_image = rearrange(
        key[:, :, text_seq_length:],
        "B N (T S) C -> (B S) N T C",
        N=attn.heads,
        T=num_frames,
        S=spatial_dim,
        C=head_dim,
    )
    return feta_score(query_image, key_image, head_dim, num_frames)
def feta_score(query_image, key_image, head_dim, num_frames):
    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)
    # Compute the frame-to-frame attention map in float32
    attn_temp = attn_temp.to(torch.float32)
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for diagonal elements
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate mean for each token's attention matrix
    # Number of off-diagonal elements per matrix is n*n - n
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
    enhance_scores = enhance_scores.clamp(min=1)
    return enhance_scores
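A minimal, self-contained sketch of what feta_score returns: it averages the off-diagonal (cross-frame) attention probabilities per head and spatial location, then collapses everything to a single scalar clamped to at least 1, which the attention processor multiplies into the image-token hidden states. The import paths and tensor sizes below are assumptions for illustration:

import torch
# Assumed to be importable from this repo, for example:
# from enhance_a_video.enhance import feta_score
# from enhance_a_video.globals import set_enhance_weight

set_enhance_weight(2.0)  # feta_score reads this module-level weight

# Illustrative layout: (batch * spatial tokens) x heads x frames x head_dim
q = torch.randn(32, 8, 13, 64)
k = torch.randn(32, 8, 13, 64)

score = feta_score(q, k, head_dim=64, num_frames=13)
print(score)  # a single scalar tensor, clamped to >= 1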


@@ -0,0 +1,31 @@
NUM_FRAMES = None
FETA_WEIGHT = None
ENABLE_FETA = False

def set_num_frames(num_frames: int):
    global NUM_FRAMES
    NUM_FRAMES = num_frames

def get_num_frames() -> int:
    return NUM_FRAMES

def enable_enhance():
    global ENABLE_FETA
    ENABLE_FETA = True

def disable_enhance():
    global ENABLE_FETA
    ENABLE_FETA = False

def is_enhance_enabled() -> bool:
    return ENABLE_FETA

def set_enhance_weight(feta_weight: float):
    global FETA_WEIGHT
    FETA_WEIGHT = feta_weight

def get_enhance_weight() -> float:
    return FETA_WEIGHT
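A quick sketch of how these module-level toggles are driven under the flow this commit wires up (the values are illustrative): the transformer records the frame count, the pipeline sets the weight and flips the flag per denoising step, and the attention processor only reads the state.

set_num_frames(13)        # called by the transformer forward from the latent shape
set_enhance_weight(2.0)   # weight coming from the Enhance-A-Video node

enable_enhance()          # pipeline turns FETA on within [start_percent, end_percent]
if is_enhance_enabled():
    print(get_enhance_weight(), get_num_frames())
disable_enhance()         # and off outside that range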


@@ -49,6 +49,25 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
    folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))

class CogVideoEnhanceAVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "weight": ("FLOAT", {"default": 1.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "FETA weight for Enhance-A-Video"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
            },
        }

    RETURN_TYPES = ("FETAARGS",)
    RETURN_NAMES = ("feta_args",)
    FUNCTION = "setargs"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"

    def setargs(self, **kwargs):
        return (kwargs,)
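The node simply packs its widget values into a dict that ComfyUI passes to the sampler as feta_args, and the pipeline reads the same keys. For example (values are illustrative):

feta_args = {"weight": 2.0, "start_percent": 0.0, "end_percent": 1.0}
# Connected to CogVideoSampler's optional feta_args input; forwarded unchanged to the pipeline.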
class CogVideoContextOptions:
    @classmethod
    def INPUT_TYPES(s):
@@ -592,6 +611,7 @@ class CogVideoSampler:
                "controlnet": ("COGVIDECONTROLNET",),
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
                "feta_args": ("FETAARGS", ),
            }
        }
@@ -601,7 +621,7 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
        mm.unload_all_models()
        mm.soft_empty_cache()
@@ -722,6 +742,7 @@ class CogVideoSampler:
            tora=tora_trajectory if tora_trajectory is not None else None,
            image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
            image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
            feta_args=feta_args,
        )

        if not model["cpu_offloading"] and model["manual_offloading"]:
            pipe.transformer.to(offload_device)
@@ -960,6 +981,7 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoLatentPreview": CogVideoLatentPreview,
    "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
    "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
    "CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoSampler": "CogVideo Sampler",
@@ -976,4 +998,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoLatentPreview": "CogVideo LatentPreview",
    "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
    "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
    "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
}


@@ -29,6 +29,7 @@ from diffusers.loaders import CogVideoXLoraLoaderMixin
from .embeddings import get_3d_rotary_pos_embed
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight

from comfy.utils import ProgressBar
@@ -351,6 +352,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        tora: Optional[dict] = None,
        image_cond_start_percent: float = 0.0,
        image_cond_end_percent: float = 1.0,
        feta_args: Optional[dict] = None,
    ):
        """
@@ -573,7 +575,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        else:
            controlnet_states = None
            control_weights = None

        # 9. Tora
        if tora is not None:
            trajectory_length = tora["video_flow_features"].shape[1]
            logger.info(f"Tora trajectory length: {trajectory_length}")
@@ -585,16 +587,32 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")

        # Enhance-A-Video setup: store the weight and remember the step range it applies to
        if feta_args is not None:
            set_enhance_weight(feta_args["weight"])
            feta_start_percent = feta_args["start_percent"]
            feta_end_percent = feta_args["end_percent"]
            enable_enhance()
        else:
            disable_enhance()
        # 11. Denoising loop
        from .latent_preview import prepare_callback
        callback = prepare_callback(self.transformer, num_inference_steps)

        # 9. Denoising loop
        comfy_pbar = ProgressBar(len(timesteps))
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            old_pred_original_sample = None  # for DPM-solver++
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                current_step_percentage = i / num_inference_steps

                # Toggle Enhance-A-Video on or off depending on the current step percentage
                if feta_args is not None:
                    if feta_start_percent <= current_step_percentage <= feta_end_percent:
                        enable_enhance()
                    else:
                        disable_enhance()

                # region context schedule sampling
                if use_context_schedule:
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -609,8 +627,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                    timestep = t.expand(latent_model_input.shape[0])

                    current_step_percentage = i / num_inference_steps

                    # use same rotary embeddings for all context windows
                    image_rotary_emb = (
                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -720,8 +736,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                current_step_percentage = i / num_inference_steps

                if image_cond_latents is not None:
                    if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
                        latent_image_input = torch.zeros_like(latent_model_input)