mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-29 06:17:09 +08:00
parent
fcc0f3e65a
commit
f16d38a5d2
@ -35,6 +35,9 @@ from diffusers.loaders import PeftAdapterMixin
|
|||||||
from diffusers.models.embeddings import apply_rotary_emb
|
from diffusers.models.embeddings import apply_rotary_emb
|
||||||
from .embeddings import CogVideoXPatchEmbed
|
from .embeddings import CogVideoXPatchEmbed
|
||||||
|
|
||||||
|
from .enhance_a_video.enhance import get_feta_scores
|
||||||
|
from .enhance_a_video.globals import is_enhance_enabled, set_num_frames
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
@ -159,6 +162,10 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||||
if not attn.is_cross_attention:
|
if not attn.is_cross_attention:
|
||||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||||
|
|
||||||
|
#feta
|
||||||
|
if is_enhance_enabled():
|
||||||
|
feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)
|
||||||
|
|
||||||
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
|
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
|
||||||
|
|
||||||
@ -173,6 +180,10 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_enhance_enabled():
|
||||||
|
hidden_states *= feta_scores
|
||||||
|
|
||||||
return hidden_states, encoder_hidden_states
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
#region Blocks
|
#region Blocks
|
||||||
@ -543,6 +554,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||||
|
|
||||||
|
set_num_frames(num_frames)
|
||||||
|
|
||||||
# 1. Time embedding
|
# 1. Time embedding
|
||||||
timesteps = timestep
|
timesteps = timestep
|
||||||
|
|||||||
0
enhance_a_video/__init__.py
Normal file
0
enhance_a_video/__init__.py
Normal file
82
enhance_a_video/enhance.py
Normal file
82
enhance_a_video/enhance.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from diffusers.models.attention import Attention
|
||||||
|
from .globals import get_enhance_weight, get_num_frames
|
||||||
|
|
||||||
|
# def get_feta_scores(query, key):
|
||||||
|
# img_q, img_k = query, key
|
||||||
|
|
||||||
|
# num_frames = get_num_frames()
|
||||||
|
|
||||||
|
# B, S, N, C = img_q.shape
|
||||||
|
|
||||||
|
# # Calculate spatial dimension
|
||||||
|
# spatial_dim = S // num_frames
|
||||||
|
|
||||||
|
# # Add time dimension between spatial and head dims
|
||||||
|
# query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
|
||||||
|
# key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)
|
||||||
|
|
||||||
|
# # Expand time dimension
|
||||||
|
# query_image = query_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C]
|
||||||
|
# key_image = key_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C]
|
||||||
|
|
||||||
|
# # Reshape to match feta_score input format: [(B S) N T C]
|
||||||
|
# query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128])
|
||||||
|
# key_image = rearrange(key_image, "b s t n c -> (b s) n t c")
|
||||||
|
|
||||||
|
# return feta_score(query_image, key_image, C, num_frames)
|
||||||
|
|
||||||
|
def get_feta_scores(
    attn: Attention,
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    text_seq_length: int,
) -> torch.Tensor:
    """Compute the Enhance-A-Video (FETA) scaling factor for one attention call.

    The text tokens at the front of the sequence are dropped, the remaining
    image tokens are regrouped so that every spatial position becomes its own
    batch entry attending across frames, and the result is handed to
    ``feta_score``.

    Args:
        attn: Attention module; only ``attn.heads`` is read.
        query: Query tensor laid out as ``(B, N, text + T * S, C)``.
        key: Key tensor with the same layout as ``query``.
        head_dim: Size ``C`` of each attention head.
        text_seq_length: Number of leading text tokens to skip on dim 2.

    Returns:
        The value returned by ``feta_score`` for the regrouped query/key.
    """
    num_frames = get_num_frames()
    # Floor division keeps the per-frame token count an exact integer instead
    # of round-tripping through a float as ``int(a / b)`` did.
    spatial_dim = (query.shape[2] - text_seq_length) // num_frames

    # One pattern and one set of axis sizes so query and key can never be
    # regrouped differently by accident.
    pattern = "B N (T S) C -> (B S) N T C"
    axis_sizes = dict(N=attn.heads, T=num_frames, S=spatial_dim, C=head_dim)

    query_image = rearrange(query[:, :, text_seq_length:], pattern, **axis_sizes)
    key_image = rearrange(key[:, :, text_seq_length:], pattern, **axis_sizes)
    return feta_score(query_image, key_image, head_dim, num_frames)
|
||||||
|
|
||||||
|
def feta_score(query_image, key_image, head_dim, num_frames):
    """Return the scalar Enhance-A-Video factor from cross-frame attention.

    A frame-to-frame attention map is computed, its diagonal (each frame
    attending to itself) is zeroed, and the mean of the remaining
    off-diagonal weights is scaled by ``num_frames + get_enhance_weight()``.

    Args:
        query_image: Query tensor whose last two dims are ``(num_frames, head_dim)``.
        key_image: Key tensor with the same layout as ``query_image``.
        head_dim: Head size, used for the ``head_dim ** -0.5`` attention scale.
        num_frames: Number of frames along the temporal axis.

    Returns:
        A 0-dim float32 tensor that is never below 1.
    """
    if num_frames <= 1:
        # A single frame has no off-diagonal entries: the mean below would be
        # 0 / 0 = NaN, and NaN passes through clamp(min=1). Return the neutral
        # factor instead so the caller's multiplication is a no-op.
        return torch.ones((), dtype=torch.float32, device=query_image.device)

    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)
    # Cast to float32 before the softmax.
    attn_temp = attn_temp.to(torch.float32)
    attn_temp = attn_temp.softmax(dim=-1)

    # Flatten every leading dim: [-1, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Boolean identity marks the diagonal; masked_fill broadcasts it over the
    # leading dim, so no explicit expand is required.
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Mean over the n*n - n off-diagonal elements of each matrix.
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
    # The factor is only allowed to amplify, never to attenuate.
    enhance_scores = enhance_scores.clamp(min=1)
    return enhance_scores
|
||||||
31
enhance_a_video/globals.py
Normal file
31
enhance_a_video/globals.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Optional

# Process-wide Enhance-A-Video state. These values are plain module globals;
# the accessors below are the intended way to read and write them.
NUM_FRAMES: Optional[int] = None      # unset until set_num_frames() is called
FETA_WEIGHT: Optional[float] = None   # unset until set_enhance_weight() is called
ENABLE_FETA: bool = False             # enhancement is off by default


def set_num_frames(num_frames: int) -> None:
    """Store the frame count for later reads via ``get_num_frames``."""
    global NUM_FRAMES
    NUM_FRAMES = num_frames


def get_num_frames() -> Optional[int]:
    """Return the stored frame count, or ``None`` if it was never set."""
    return NUM_FRAMES


def enable_enhance() -> None:
    """Turn the enhancement flag on."""
    global ENABLE_FETA
    ENABLE_FETA = True


def disable_enhance() -> None:
    """Turn the enhancement flag off."""
    global ENABLE_FETA
    ENABLE_FETA = False


def is_enhance_enabled() -> bool:
    """Return whether the enhancement flag is currently on."""
    return ENABLE_FETA


def set_enhance_weight(feta_weight: float) -> None:
    """Store the enhancement weight for later reads via ``get_enhance_weight``."""
    global FETA_WEIGHT
    FETA_WEIGHT = feta_weight


def get_enhance_weight() -> Optional[float]:
    """Return the stored enhancement weight, or ``None`` if it was never set."""
    return FETA_WEIGHT
|
||||||
25
nodes.py
25
nodes.py
@ -49,6 +49,25 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
|
|||||||
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
||||||
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
||||||
|
|
||||||
|
class CogVideoEnhanceAVideo:
    """Node that bundles Enhance-A-Video settings into a ``FETAARGS`` output."""

    @classmethod
    def INPUT_TYPES(s):
        # Each widget is described separately, then collected under "required".
        weight_spec = (
            "FLOAT",
            {"default": 1.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"},
        )
        start_spec = (
            "FLOAT",
            {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"},
        )
        end_spec = (
            "FLOAT",
            {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"},
        )
        required_inputs = {
            "weight": weight_spec,
            "start_percent": start_spec,
            "end_percent": end_spec,
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("FETAARGS",)
    RETURN_NAMES = ("feta_args",)
    FUNCTION = "setargs"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"

    def setargs(self, **kwargs):
        """Return the received keyword arguments as a one-element tuple."""
        feta_args = kwargs
        return (feta_args,)
|
||||||
|
|
||||||
class CogVideoContextOptions:
|
class CogVideoContextOptions:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -592,6 +611,7 @@ class CogVideoSampler:
|
|||||||
"controlnet": ("COGVIDECONTROLNET",),
|
"controlnet": ("COGVIDECONTROLNET",),
|
||||||
"tora_trajectory": ("TORAFEATURES", ),
|
"tora_trajectory": ("TORAFEATURES", ),
|
||||||
"fastercache": ("FASTERCACHEARGS", ),
|
"fastercache": ("FASTERCACHEARGS", ),
|
||||||
|
"feta_args": ("FETAARGS", ),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,7 +621,7 @@ class CogVideoSampler:
|
|||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
|
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
|
||||||
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
|
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
|
||||||
mm.unload_all_models()
|
mm.unload_all_models()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
@ -722,6 +742,7 @@ class CogVideoSampler:
|
|||||||
tora=tora_trajectory if tora_trajectory is not None else None,
|
tora=tora_trajectory if tora_trajectory is not None else None,
|
||||||
image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
|
image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
|
||||||
image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
|
image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
|
||||||
|
feta_args=feta_args,
|
||||||
)
|
)
|
||||||
if not model["cpu_offloading"] and model["manual_offloading"]:
|
if not model["cpu_offloading"] and model["manual_offloading"]:
|
||||||
pipe.transformer.to(offload_device)
|
pipe.transformer.to(offload_device)
|
||||||
@ -960,6 +981,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CogVideoLatentPreview": CogVideoLatentPreview,
|
"CogVideoLatentPreview": CogVideoLatentPreview,
|
||||||
"CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
|
"CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
|
||||||
"CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
|
"CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
|
||||||
|
"CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"CogVideoSampler": "CogVideo Sampler",
|
"CogVideoSampler": "CogVideo Sampler",
|
||||||
@ -976,4 +998,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"CogVideoLatentPreview": "CogVideo LatentPreview",
|
"CogVideoLatentPreview": "CogVideo LatentPreview",
|
||||||
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
|
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
|
||||||
"CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
|
"CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
|
||||||
|
"CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from diffusers.loaders import CogVideoXLoraLoaderMixin
|
|||||||
|
|
||||||
from .embeddings import get_3d_rotary_pos_embed
|
from .embeddings import get_3d_rotary_pos_embed
|
||||||
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
||||||
|
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight
|
||||||
|
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
@ -351,6 +352,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
tora: Optional[dict] = None,
|
tora: Optional[dict] = None,
|
||||||
image_cond_start_percent: float = 0.0,
|
image_cond_start_percent: float = 0.0,
|
||||||
image_cond_end_percent: float = 1.0,
|
image_cond_end_percent: float = 1.0,
|
||||||
|
feta_args: Optional[dict] = None,
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -573,7 +575,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
else:
|
else:
|
||||||
controlnet_states = None
|
controlnet_states = None
|
||||||
control_weights= None
|
control_weights= None
|
||||||
|
# 9. Tora
|
||||||
if tora is not None:
|
if tora is not None:
|
||||||
trajectory_length = tora["video_flow_features"].shape[1]
|
trajectory_length = tora["video_flow_features"].shape[1]
|
||||||
logger.info(f"Tora trajectory length: {trajectory_length}")
|
logger.info(f"Tora trajectory length: {trajectory_length}")
|
||||||
@ -585,16 +587,32 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
|
|
||||||
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
|
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
|
||||||
|
|
||||||
|
if feta_args is not None:
|
||||||
|
set_enhance_weight(feta_args["weight"])
|
||||||
|
feta_start_percent = feta_args["start_percent"]
|
||||||
|
feta_end_percent = feta_args["end_percent"]
|
||||||
|
enable_enhance()
|
||||||
|
else:
|
||||||
|
disable_enhance()
|
||||||
|
|
||||||
|
# 11. Denoising loop
|
||||||
from .latent_preview import prepare_callback
|
from .latent_preview import prepare_callback
|
||||||
callback = prepare_callback(self.transformer, num_inference_steps)
|
callback = prepare_callback(self.transformer, num_inference_steps)
|
||||||
|
|
||||||
# 9. Denoising loop
|
|
||||||
comfy_pbar = ProgressBar(len(timesteps))
|
comfy_pbar = ProgressBar(len(timesteps))
|
||||||
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
||||||
old_pred_original_sample = None # for DPM-solver++
|
old_pred_original_sample = None # for DPM-solver++
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
if self.interrupt:
|
if self.interrupt:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
current_step_percentage = i / num_inference_steps
|
||||||
|
|
||||||
|
if feta_args is not None:
|
||||||
|
if feta_start_percent <= current_step_percentage <= feta_end_percent:
|
||||||
|
enable_enhance()
|
||||||
|
else:
|
||||||
|
disable_enhance()
|
||||||
# region context schedule sampling
|
# region context schedule sampling
|
||||||
if use_context_schedule:
|
if use_context_schedule:
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
@ -609,8 +627,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
timestep = t.expand(latent_model_input.shape[0])
|
timestep = t.expand(latent_model_input.shape[0])
|
||||||
|
|
||||||
current_step_percentage = i / num_inference_steps
|
|
||||||
|
|
||||||
# use same rotary embeddings for all context windows
|
# use same rotary embeddings for all context windows
|
||||||
image_rotary_emb = (
|
image_rotary_emb = (
|
||||||
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
|
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
|
||||||
@ -720,8 +736,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
current_step_percentage = i / num_inference_steps
|
|
||||||
|
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
|
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
|
||||||
latent_image_input = torch.zeros_like(latent_model_input)
|
latent_image_input = torch.zeros_like(latent_model_input)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user