Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 21:04:23 +08:00)

Commit 0ea77bc63f (parent f16d38a5d2): add basic spatiotemporal guidance impl
@@ -113,6 +113,8 @@ class CogVideoXAttnProcessor2_0:
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.attention_mode = attention_mode
        self.attn_func = attn_func
        self.stg_mode: str = None

    def __call__(
        self,
        attn: Attention,

@@ -121,6 +123,16 @@ class CogVideoXAttnProcessor2_0:
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.stg_mode == "STG-R":
            return hidden_states, encoder_hidden_states

        if self.stg_mode == "STG-A":
            hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
            encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)

            hidden_states = torch.cat([hidden_states_uncond, hidden_states], dim=0)
            encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states], dim=0)

        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -184,6 +196,78 @@ class CogVideoXAttnProcessor2_0:
        if is_enhance_enabled():
            hidden_states *= feta_scores

        if self.stg_mode == "STG-A":
            text_seq_length = encoder_hidden_states.size(1)

            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )

            if attention_mask is not None:
                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

            if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
                hidden_states = hidden_states.to(attn.to_q.weight.dtype)

            if "fused" not in self.attention_mode:
                query_perturb = attn.to_q(hidden_states)
                key_perturb = attn.to_k(hidden_states)
                value_perturb = attn.to_v(hidden_states)
            else:
                qkv = attn.to_qkv(hidden_states)
                split_size = qkv.shape[-1] // 3
                query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)

            inner_dim = key_perturb.shape[-1]
            head_dim = inner_dim // attn.heads

            query_perturb = query_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key_perturb = key_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value_perturb = value_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            if attn.norm_q is not None:
                query_perturb = attn.norm_q(query_perturb)
            if attn.norm_k is not None:
                key_perturb = attn.norm_k(key_perturb)

            # Apply RoPE if needed
            if image_rotary_emb is not None:
                query_perturb[:, :, text_seq_length:] = apply_rotary_emb(query_perturb[:, :, text_seq_length:], image_rotary_emb)
                if not attn.is_cross_attention:
                    key_perturb[:, :, text_seq_length:] = apply_rotary_emb(key_perturb[:, :, text_seq_length:], image_rotary_emb)

            full_seq_length = query_perturb.size(2)
            identity_block_size = full_seq_length - text_seq_length

            full_mask = torch.zeros((full_seq_length, full_seq_length), device=query_perturb.device, dtype=query_perturb.dtype)

            full_mask[:identity_block_size, :identity_block_size] = float("-inf")
            full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

            full_mask = full_mask.unsqueeze(0).unsqueeze(0)
            hidden_states_perturb = self.attn_func(
                query_perturb, key_perturb, value_perturb, attn_mask=full_mask, is_causal=False
            )

            if self.attention_mode != "comfy":
                hidden_states_perturb = hidden_states_perturb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

            # linear proj
            hidden_states_perturb = attn.to_out[0](hidden_states_perturb)
            # dropout
            hidden_states_perturb = attn.to_out[1](hidden_states_perturb)

            encoder_hidden_states_perturb, hidden_states_perturb = hidden_states_perturb.split(
                [text_seq_length, hidden_states_perturb.size(1) - text_seq_length], dim=1
            )

            hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0)
            encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_perturb], dim=0)

        return hidden_states, encoder_hidden_states

#region Blocks
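For reference, a minimal standalone sketch (toy sizes only, not part of the commit) of the attn_mask construction used in the STG-A branch above: an additive mask in which the first identity_block_size positions may only attend to themselves (diagonal-only attention), while every other position keeps full attention. This is the attention perturbation that produces the third, "perturbed" batch entry.

    import torch

    # Toy sizes, purely for illustration
    text_seq_length = 4
    identity_block_size = 6
    full_seq_length = text_seq_length + identity_block_size

    # Same construction as the STG-A branch above: start from an all-zero additive
    # mask, block the identity region entirely, then re-open its diagonal.
    full_mask = torch.zeros((full_seq_length, full_seq_length))
    full_mask[:identity_block_size, :identity_block_size] = float("-inf")
    full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

    # Broadcastable over (batch, heads) when passed as attn_mask
    full_mask = full_mask.unsqueeze(0).unsqueeze(0)
    print(full_mask.shape)  # torch.Size([1, 1, 10, 10])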
nodes.py

@@ -1,4 +1,4 @@
import os
import os, re
import torch
import json
from einops import rearrange
@@ -96,6 +96,31 @@ class CogVideoContextOptions:

        return (context_options,)

class CogVideoSTG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mode": (["STG-A", "STG-R"],),
                "blocks": ("STRING", {"default": "30", "tooltip": "Block index to apply STG"}),
                "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤2.0"}),
                "rescale": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤1.0"}),
            },
        }
    RETURN_TYPES = ("STGARGS",)
    RETURN_NAMES = ("stg_args",)
    FUNCTION = "setargs"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "https://github.com/junhahyung/STGuidance"

    def setargs(self, mode, blocks, scale, rescale):
        return ({
            "stg_mode": mode,
            "stg_layers_idx": [int(x) for x in re.findall(r"\d+", blocks)],
            "stg_scale": scale,
            "stg_rescaling": rescale
        }, )

class CogVideoTransformerEdit:
    @classmethod
    def INPUT_TYPES(s):
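As a quick illustration (not part of the commit), this is roughly what setargs produces for a typical input; the regex simply pulls every integer out of the blocks string, so comma- or space-separated lists both work.

    import re

    def setargs(mode, blocks, scale, rescale):
        # Same logic as CogVideoSTG.setargs above, without the node wrapper
        return ({
            "stg_mode": mode,
            "stg_layers_idx": [int(x) for x in re.findall(r"\d+", blocks)],
            "stg_scale": scale,
            "stg_rescaling": rescale
        }, )

    print(setargs("STG-A", "28, 30", 1.0, 0.7))
    # ({'stg_mode': 'STG-A', 'stg_layers_idx': [28, 30], 'stg_scale': 1.0, 'stg_rescaling': 0.7},)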
@@ -612,6 +637,7 @@ class CogVideoSampler:
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
                "feta_args": ("FETAARGS", ),
                "stg_args": ("STGARGS", ),
            }
        }
@@ -621,7 +647,8 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None,
                tora_trajectory=None, fastercache=None, feta_args=None, stg_args=None):
        mm.unload_all_models()
        mm.soft_empty_cache()
@@ -743,6 +770,7 @@ class CogVideoSampler:
            image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
            image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
            feta_args=feta_args,
            **(stg_args if stg_args else {}),
        )
        if not model["cpu_offloading"] and model["manual_offloading"]:
            pipe.transformer.to(offload_device)
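For orientation, a small illustrative sketch of how the dict produced by CogVideoSTG reaches the pipeline (the call_pipe stub below is hypothetical; only the keyword names come from this commit): the sampler splats it with **(stg_args if stg_args else {}), so its keys map one-to-one onto the new stg_* keyword arguments, and a missing stg_args leaves the STG-disabled defaults in place.

    def call_pipe(feta_args=None, stg_mode=None, stg_layers_idx=[], stg_scale=0.0, stg_rescaling=None):
        # Hypothetical stand-in for the pipeline __call__, keeping only the relevant kwargs
        return {"stg_mode": stg_mode, "stg_layers_idx": stg_layers_idx,
                "stg_scale": stg_scale, "stg_rescaling": stg_rescaling}

    stg_args = {"stg_mode": "STG-A", "stg_layers_idx": [30], "stg_scale": 1.0, "stg_rescaling": 0.7}
    print(call_pipe(**(stg_args if stg_args else {})))   # STG enabled with the node's values

    stg_args = None
    print(call_pipe(**(stg_args if stg_args else {})))   # falls back to the STG-disabled defaults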
@@ -973,6 +1001,7 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
    "CogVideoTransformerEdit": CogVideoTransformerEdit,
    "CogVideoContextOptions": CogVideoContextOptions,
    "CogVideoSTG": CogVideoSTG,
    "CogVideoControlNet": CogVideoControlNet,
    "ToraEncodeTrajectory": ToraEncodeTrajectory,
    "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,

@@ -991,6 +1020,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
    "CogVideoTransformerEdit": "CogVideo TransformerEdit",
    "CogVideoContextOptions": "CogVideo Context Options",
    "CogVideoSTG": "CogVideo Spatiotemporal Guidance",
    "ToraEncodeTrajectory": "Tora Encode Trajectory",
    "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
    "CogVideoXFasterCache": "CogVideoX FasterCache",
@@ -316,6 +316,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def do_spatiotemporal_guidance(self):
        return self._stg_scale > 0

    @property
    def num_timesteps(self):
        return self._num_timesteps

@@ -324,6 +328,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    def interrupt(self):
        return self._interrupt

    def extract_attn_layers(self):
        for key, mod in self.transformer.named_modules():
            if "attn1" in key and "to" not in key and "add" not in key and "norm" not in key:
                yield mod

    @torch.no_grad()
    def __call__(
        self,
@@ -353,6 +362,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        image_cond_start_percent: float = 0.0,
        image_cond_end_percent: float = 1.0,
        feta_args: Optional[dict] = None,
        stg_mode: Optional[str] = None, # "STG-A",
        stg_layers_idx: Optional[List[int]] = [], # [30],
        stg_scale: Optional[float] = 0.0, # 4.0
        stg_rescaling: Optional[float] = None, # 0.7
    ):
        """
@@ -409,9 +422,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._stg_scale = stg_scale
        self._guidance_scale = guidance_scale
        self._interrupt = False

        for idx, mod in enumerate(self.extract_attn_layers()):
            mod.stg_mode = stg_mode if idx in stg_layers_idx else None

        # 2. Default call parameters

        batch_size = prompt_embeds.shape[0]
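A small self-contained sketch (illustrative only, using a dummy module tree rather than the real CogVideoX transformer) of how extract_attn_layers plus the loop above toggle stg_mode on just the block indices chosen in stg_layers_idx.

    import torch.nn as nn

    class DummyAttn(nn.Module):
        def __init__(self):
            super().__init__()
            self.stg_mode = None
            self.to_q = nn.Linear(8, 8)   # "to" submodules are skipped by the name filter

    class DummyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn1 = DummyAttn()

    transformer = nn.Sequential(*[DummyBlock() for _ in range(4)])

    def extract_attn_layers(model):
        # Same name filter as the pipeline helper above
        for key, mod in model.named_modules():
            if "attn1" in key and "to" not in key and "add" not in key and "norm" not in key:
                yield mod

    stg_mode, stg_layers_idx = "STG-A", [2]
    for idx, mod in enumerate(extract_attn_layers(transformer)):
        mod.stg_mode = stg_mode if idx in stg_layers_idx else None

    print([blk.attn1.stg_mode for blk in transformer])  # [None, None, 'STG-A', None]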
@@ -421,8 +438,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale[0] > 1.0

        if do_classifier_free_guidance:
        if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)

        prompt_embeds = prompt_embeds.to(self.vae_dtype)

        # 4. Prepare timesteps
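To make the batch layout concrete, a toy sketch (illustrative tensors only) of the conditioning stack built above: plain CFG runs a 2-way batch [negative, positive], while CFG plus STG runs a 3-way batch [negative, positive, positive], where the third copy is the one routed through the perturbed attention layers.

    import torch

    negative_prompt_embeds = torch.zeros(1, 226, 4096)   # toy shapes: (batch, text_seq, dim)
    prompt_embeds = torch.ones(1, 226, 4096)

    do_spatiotemporal_guidance = True
    if not do_spatiotemporal_guidance:
        # plain CFG: [uncond, cond]
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    else:
        # CFG + STG: [uncond, cond, cond]; the last copy feeds the perturbed forward pass
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)

    print(prompt_embeds.shape)  # torch.Size([3, 226, 4096])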
@@ -615,13 +635,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                disable_enhance()
        # region context schedule sampling
        if use_context_schedule:
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                latent_model_input = torch.cat([latents] * 2)
            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                latent_model_input = torch.cat([latents] * 3)
            else:
                latent_model_input = latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            counter = torch.zeros_like(latent_model_input)
            noise_pred = torch.zeros_like(latent_model_input)

            if image_cond_latents is not None:
                latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
                if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                    latent_image_input = torch.cat([image_cond_latents] * 2)
                elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                    latent_image_input = torch.cat([image_cond_latents] * 3)
                else:
                    latent_image_input = image_cond_latents

                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -708,10 +740,20 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                noise_pred = noise_pred.float()

            noise_pred /= counter
            if do_classifier_free_guidance:
            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)

            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
                noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) \
                    + self._stg_scale * (noise_pred_text - noise_pred_perturb)

                if stg_rescaling:
                    factor = noise_pred_text.std() / noise_pred.std()
                    factor = stg_rescaling * factor + (1 - stg_rescaling)
                    noise_pred = noise_pred * factor

            # compute the previous noisy sample x_t -> x_t-1
            if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
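As a compact reference, an illustrative sketch of the update used in both sampling paths (not additional repository code): the model output is chunked into the three batch entries and combined as uncond + cfg*(text - uncond) + stg*(text - perturb); the optional rescaling then blends the result's standard deviation back toward that of the text-conditioned prediction.

    import torch

    def combine_cfg_stg(noise_pred, cfg_scale, stg_scale, stg_rescaling=None):
        # noise_pred is the stacked prediction for the [uncond, text, perturb] batch
        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
        out = noise_pred_uncond \
            + cfg_scale * (noise_pred_text - noise_pred_uncond) \
            + stg_scale * (noise_pred_text - noise_pred_perturb)
        if stg_rescaling:
            # blend std toward the text-conditioned prediction, as in the hunk above
            factor = noise_pred_text.std() / out.std()
            factor = stg_rescaling * factor + (1 - stg_rescaling)
            out = out * factor
        return out

    # toy usage; the latent shape is an assumption for illustration only
    noise_pred = torch.randn(3, 16, 13, 60, 90)
    print(combine_cfg_stg(noise_pred, cfg_scale=6.0, stg_scale=1.0, stg_rescaling=0.7).shape)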
@@ -733,25 +775,54 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

        # region sampling
        else:
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                latent_model_input = torch.cat([latents] * 2)
            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                latent_model_input = torch.cat([latents] * 3)
            else:
                latent_model_input = latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            if image_cond_latents is not None:
                if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
                    latent_image_input = torch.zeros_like(latent_model_input)
                else:
                    latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                        latent_image_input = torch.cat([image_cond_latents] * 2)
                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                        latent_image_input = torch.cat([image_cond_latents] * 3)
                    else:
                        latent_image_input = image_cond_latents

                if fun_mask is not None: #for fun img2vid and interpolation
                    fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                        fun_inpaint_mask = torch.cat([fun_mask] * 2)
                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                        fun_inpaint_mask = torch.cat([fun_mask] * 3)
                    else:
                        fun_inpaint_mask = fun_mask

                    masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
                    latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
                else:
                    latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
            else: # for Fun inpaint vid2vid
                if fun_mask is not None:
                    fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
                    fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
                    fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                        fun_inpaint_mask = torch.cat([fun_mask] * 2)
                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                        fun_inpaint_mask = torch.cat([fun_mask] * 3)
                    else:
                        fun_inpaint_mask = fun_mask

                    if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                        fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2)
                    elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                        fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 3)
                    else:
                        fun_inpaint_masked_video_latents = fun_masked_video_latents

                    latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -792,10 +863,20 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                    (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                )

            if do_classifier_free_guidance:
            if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)

            elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
                noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
                noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) \
                    + self._stg_scale * (noise_pred_text - noise_pred_perturb)

                if stg_rescaling:
                    factor = noise_pred_text.std() / noise_pred.std()
                    factor = stg_rescaling * factor + (1 - stg_rescaling)
                    noise_pred = noise_pred * factor

            # compute the previous noisy sample x_t -> x_t-1
            if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                latents = self.scheduler.step(noise_pred, t, latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False)[0]