mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 07:04:24 +08:00
parent
fcc0f3e65a
commit
f16d38a5d2
@ -35,6 +35,9 @@ from diffusers.loaders import PeftAdapterMixin
|
||||
from diffusers.models.embeddings import apply_rotary_emb
|
||||
from .embeddings import CogVideoXPatchEmbed
|
||||
|
||||
from .enhance_a_video.enhance import get_feta_scores
|
||||
from .enhance_a_video.globals import is_enhance_enabled, set_num_frames
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@ -159,6 +162,10 @@ class CogVideoXAttnProcessor2_0:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
#feta
|
||||
if is_enhance_enabled():
|
||||
feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)
|
||||
|
||||
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
|
||||
|
||||
@ -173,6 +180,10 @@ class CogVideoXAttnProcessor2_0:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
|
||||
if is_enhance_enabled():
|
||||
hidden_states *= feta_scores
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
#region Blocks
|
||||
@ -543,6 +554,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return_dict: bool = True,
|
||||
):
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
|
||||
set_num_frames(num_frames)
|
||||
|
||||
# 1. Time embedding
|
||||
timesteps = timestep
|
||||
|
||||
0
enhance_a_video/__init__.py
Normal file
0
enhance_a_video/__init__.py
Normal file
82
enhance_a_video/enhance.py
Normal file
82
enhance_a_video/enhance.py
Normal file
@ -0,0 +1,82 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from diffusers.models.attention import Attention
|
||||
from .globals import get_enhance_weight, get_num_frames
|
||||
|
||||
# def get_feta_scores(query, key):
|
||||
# img_q, img_k = query, key
|
||||
|
||||
# num_frames = get_num_frames()
|
||||
|
||||
# B, S, N, C = img_q.shape
|
||||
|
||||
# # Calculate spatial dimension
|
||||
# spatial_dim = S // num_frames
|
||||
|
||||
# # Add time dimension between spatial and head dims
|
||||
# query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
|
||||
# key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)
|
||||
|
||||
# # Expand time dimension
|
||||
# query_image = query_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C]
|
||||
# key_image = key_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C]
|
||||
|
||||
# # Reshape to match feta_score input format: [(B S) N T C]
|
||||
# query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128])
|
||||
# key_image = rearrange(key_image, "b s t n c -> (b s) n t c")
|
||||
|
||||
# return feta_score(query_image, key_image, C, num_frames)
|
||||
|
||||
def get_feta_scores(
|
||||
attn: Attention,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_dim: int,
|
||||
text_seq_length: int,
|
||||
) -> torch.Tensor:
|
||||
num_frames = get_num_frames()
|
||||
spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)
|
||||
|
||||
query_image = rearrange(
|
||||
query[:, :, text_seq_length:],
|
||||
"B N (T S) C -> (B S) N T C",
|
||||
N=attn.heads,
|
||||
T=num_frames,
|
||||
S=spatial_dim,
|
||||
C=head_dim,
|
||||
)
|
||||
key_image = rearrange(
|
||||
key[:, :, text_seq_length:],
|
||||
"B N (T S) C -> (B S) N T C",
|
||||
N=attn.heads,
|
||||
T=num_frames,
|
||||
S=spatial_dim,
|
||||
C=head_dim,
|
||||
)
|
||||
return feta_score(query_image, key_image, head_dim, num_frames)
|
||||
|
||||
def feta_score(query_image, key_image, head_dim, num_frames):
|
||||
scale = head_dim**-0.5
|
||||
query_image = query_image * scale
|
||||
attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32
|
||||
attn_temp = attn_temp.to(torch.float32)
|
||||
attn_temp = attn_temp.softmax(dim=-1)
|
||||
|
||||
# Reshape to [batch_size * num_tokens, num_frames, num_frames]
|
||||
attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
|
||||
|
||||
# Create a mask for diagonal elements
|
||||
diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
|
||||
diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
|
||||
|
||||
# Zero out diagonal elements
|
||||
attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
|
||||
|
||||
# Calculate mean for each token's attention matrix
|
||||
# Number of off-diagonal elements per matrix is n*n - n
|
||||
num_off_diag = num_frames * num_frames - num_frames
|
||||
mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
|
||||
|
||||
enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
|
||||
enhance_scores = enhance_scores.clamp(min=1)
|
||||
return enhance_scores
|
||||
31
enhance_a_video/globals.py
Normal file
31
enhance_a_video/globals.py
Normal file
@ -0,0 +1,31 @@
|
||||
NUM_FRAMES = None
|
||||
FETA_WEIGHT = None
|
||||
ENABLE_FETA = False
|
||||
|
||||
def set_num_frames(num_frames: int):
|
||||
global NUM_FRAMES
|
||||
NUM_FRAMES = num_frames
|
||||
|
||||
|
||||
def get_num_frames() -> int:
|
||||
return NUM_FRAMES
|
||||
|
||||
|
||||
def enable_enhance():
|
||||
global ENABLE_FETA
|
||||
ENABLE_FETA = True
|
||||
|
||||
def disable_enhance():
|
||||
global ENABLE_FETA
|
||||
ENABLE_FETA = False
|
||||
|
||||
def is_enhance_enabled() -> bool:
|
||||
return ENABLE_FETA
|
||||
|
||||
def set_enhance_weight(feta_weight: float):
|
||||
global FETA_WEIGHT
|
||||
FETA_WEIGHT = feta_weight
|
||||
|
||||
|
||||
def get_enhance_weight() -> float:
|
||||
return FETA_WEIGHT
|
||||
25
nodes.py
25
nodes.py
@ -49,6 +49,25 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
|
||||
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
||||
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
||||
|
||||
class CogVideoEnhanceAVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"weight": ("FLOAT", {"default": 1.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("FETAARGS",)
|
||||
RETURN_NAMES = ("feta_args",)
|
||||
FUNCTION = "setargs"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
|
||||
|
||||
def setargs(self, **kwargs):
|
||||
return (kwargs, )
|
||||
|
||||
class CogVideoContextOptions:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -592,6 +611,7 @@ class CogVideoSampler:
|
||||
"controlnet": ("COGVIDECONTROLNET",),
|
||||
"tora_trajectory": ("TORAFEATURES", ),
|
||||
"fastercache": ("FASTERCACHEARGS", ),
|
||||
"feta_args": ("FETAARGS", ),
|
||||
}
|
||||
}
|
||||
|
||||
@ -601,7 +621,7 @@ class CogVideoSampler:
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
|
||||
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
|
||||
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
@ -722,6 +742,7 @@ class CogVideoSampler:
|
||||
tora=tora_trajectory if tora_trajectory is not None else None,
|
||||
image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
|
||||
image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
|
||||
feta_args=feta_args,
|
||||
)
|
||||
if not model["cpu_offloading"] and model["manual_offloading"]:
|
||||
pipe.transformer.to(offload_device)
|
||||
@ -960,6 +981,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoLatentPreview": CogVideoLatentPreview,
|
||||
"CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
|
||||
"CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
|
||||
"CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoSampler": "CogVideo Sampler",
|
||||
@ -976,4 +998,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoLatentPreview": "CogVideo LatentPreview",
|
||||
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
|
||||
"CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
|
||||
"CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
|
||||
}
|
||||
|
||||
@ -29,6 +29,7 @@ from diffusers.loaders import CogVideoXLoraLoaderMixin
|
||||
|
||||
from .embeddings import get_3d_rotary_pos_embed
|
||||
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
||||
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight
|
||||
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
@ -351,6 +352,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
tora: Optional[dict] = None,
|
||||
image_cond_start_percent: float = 0.0,
|
||||
image_cond_end_percent: float = 1.0,
|
||||
feta_args: Optional[dict] = None,
|
||||
|
||||
):
|
||||
"""
|
||||
@ -573,7 +575,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
else:
|
||||
controlnet_states = None
|
||||
control_weights= None
|
||||
|
||||
# 9. Tora
|
||||
if tora is not None:
|
||||
trajectory_length = tora["video_flow_features"].shape[1]
|
||||
logger.info(f"Tora trajectory length: {trajectory_length}")
|
||||
@ -585,16 +587,32 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
|
||||
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
|
||||
|
||||
if feta_args is not None:
|
||||
set_enhance_weight(feta_args["weight"])
|
||||
feta_start_percent = feta_args["start_percent"]
|
||||
feta_end_percent = feta_args["end_percent"]
|
||||
enable_enhance()
|
||||
else:
|
||||
disable_enhance()
|
||||
|
||||
# 11. Denoising loop
|
||||
from .latent_preview import prepare_callback
|
||||
callback = prepare_callback(self.transformer, num_inference_steps)
|
||||
|
||||
# 9. Denoising loop
|
||||
comfy_pbar = ProgressBar(len(timesteps))
|
||||
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
||||
old_pred_original_sample = None # for DPM-solver++
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
current_step_percentage = i / num_inference_steps
|
||||
|
||||
if feta_args is not None:
|
||||
if feta_start_percent <= current_step_percentage <= feta_end_percent:
|
||||
enable_enhance()
|
||||
else:
|
||||
disable_enhance()
|
||||
# region context schedule sampling
|
||||
if use_context_schedule:
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@ -609,8 +627,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
current_step_percentage = i / num_inference_steps
|
||||
|
||||
# use same rotary embeddings for all context windows
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
|
||||
@ -720,8 +736,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
current_step_percentage = i / num_inference_steps
|
||||
|
||||
if image_cond_latents is not None:
|
||||
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
|
||||
latent_image_input = torch.zeros_like(latent_model_input)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user