add basic spatiotemporal guidance impl

This commit is contained in:
zhilemann 2024-12-22 03:02:23 +03:00 committed by GitHub
parent f16d38a5d2
commit 0ea77bc63f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 213 additions and 18 deletions

View File

@ -113,6 +113,8 @@ class CogVideoXAttnProcessor2_0:
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.attention_mode = attention_mode
self.attn_func = attn_func
self.stg_mode: Optional[str] = None
def __call__(
self,
attn: Attention,
@ -121,6 +123,16 @@ class CogVideoXAttnProcessor2_0:
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.stg_mode == "STG-R":
return hidden_states, encoder_hidden_states
if self.stg_mode == "STG-A":
hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond], dim=0)
encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@ -166,9 +178,9 @@ class CogVideoXAttnProcessor2_0:
#feta
if is_enhance_enabled():
feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
if self.attention_mode != "comfy":
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@ -184,6 +196,78 @@ class CogVideoXAttnProcessor2_0:
if is_enhance_enabled():
hidden_states *= feta_scores
if self.stg_mode == "STG-A":
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
if not "fused" in self.attention_mode:
query_perturb = attn.to_q(hidden_states)
key_perturb = attn.to_k(hidden_states)
value_perturb = attn.to_v(hidden_states)
else:
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)
inner_dim = key_perturb.shape[-1]
head_dim = inner_dim // attn.heads
query_perturb = query_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_perturb = key_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_perturb = value_perturb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query_perturb = attn.norm_q(query_perturb)
if attn.norm_k is not None:
key_perturb = attn.norm_k(key_perturb)
# Apply RoPE if needed
if image_rotary_emb is not None:
query_perturb[:, :, text_seq_length:] = apply_rotary_emb(query_perturb[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key_perturb[:, :, text_seq_length:] = apply_rotary_emb(key_perturb[:, :, text_seq_length:], image_rotary_emb)
full_seq_length = query_perturb.size(2)
identity_block_size = full_seq_length - text_seq_length
full_mask = torch.zeros((full_seq_length, full_seq_length), device=query_perturb.device, dtype=query_perturb.dtype)
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
full_mask = full_mask.unsqueeze(0).unsqueeze(0)
hidden_states_perturb = self.attn_func(
query_perturb, key_perturb, value_perturb, attn_mask=full_mask, is_causal=False
)
if self.attention_mode != "comfy":
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0)
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_perturb], dim=0)
return hidden_states, encoder_hidden_states
#region Blocks
@ -731,4 +815,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
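For reference, the core of the STG-A perturbed path above is an additive attention mask that keeps one leading block of the concatenated sequence from attending to anything but itself (its diagonal), while the rest of the attention is left untouched. A minimal standalone sketch, with made-up sizes rather than the real CogVideoX dimensions:

import torch
import torch.nn.functional as F

# toy sizes, not the real CogVideoX dimensions
batch, heads, head_dim = 1, 2, 16
text_seq_length, video_seq_length = 4, 8
full_seq_length = text_seq_length + video_seq_length

query = torch.randn(batch, heads, full_seq_length, head_dim)
key = torch.randn(batch, heads, full_seq_length, head_dim)
value = torch.randn(batch, heads, full_seq_length, head_dim)

# -inf blocks attention inside the leading identity block; the diagonal is re-opened
# so each of those tokens can still attend to itself
identity_block_size = full_seq_length - text_seq_length
full_mask = torch.zeros((full_seq_length, full_seq_length))
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # broadcast over batch and heads

hidden_states_perturb = F.scaled_dot_product_attention(
    query, key, value, attn_mask=full_mask, is_causal=False
)
print(hidden_states_perturb.shape)  # torch.Size([1, 2, 12, 16])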

View File

@ -1,4 +1,4 @@
import os
import os, re
import torch
import json
from einops import rearrange
@ -96,6 +96,31 @@ class CogVideoContextOptions:
return (context_options,)
class CogVideoSTG:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mode": (["STG-A", "STG-R"],),
"blocks": ("STRING", {"default": "30", "tooltip": "Block index to apply STG"}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤2.0"}),
"rescale": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Recommended values are ≤1.0"}),
},
}
RETURN_TYPES = ("STGARGS",)
RETURN_NAMES = ("stg_args",)
FUNCTION = "setargs"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "https://github.com/junhahyung/STGuidance"
def setargs(self, mode, blocks, scale, rescale):
return ({
"stg_mode": mode,
"stg_layers_idx": [int(x) for x in re.findall("\d+", blocks)],
"stg_scale": scale,
"stg_rescaling": rescale
}, )
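The stg_args dict this node returns is plain data; a small standalone helper (parse_stg_args is a hypothetical name mirroring setargs above) shows what the sampler receives for a typical blocks string:

import re

def parse_stg_args(mode: str, blocks: str, scale: float, rescale: float) -> dict:
    # every integer found in the blocks string becomes a layer index
    return {
        "stg_mode": mode,
        "stg_layers_idx": [int(x) for x in re.findall(r"\d+", blocks)],
        "stg_scale": scale,
        "stg_rescaling": rescale,
    }

print(parse_stg_args("STG-A", "25, 30", 1.0, 0.7))
# {'stg_mode': 'STG-A', 'stg_layers_idx': [25, 30], 'stg_scale': 1.0, 'stg_rescaling': 0.7}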
class CogVideoTransformerEdit:
@classmethod
def INPUT_TYPES(s):
@ -612,6 +637,7 @@ class CogVideoSampler:
"tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS", ),
"feta_args": ("FETAARGS", ),
"stg_args": ("STGARGS", ),
}
}
@ -621,7 +647,8 @@ class CogVideoSampler:
CATEGORY = "CogVideoWrapper"
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None,
tora_trajectory=None, fastercache=None, feta_args=None, stg_args=None):
mm.unload_all_models()
mm.soft_empty_cache()
@ -743,6 +770,7 @@ class CogVideoSampler:
image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
feta_args=feta_args,
**(stg_args if stg_args else {}),
)
if not model["cpu_offloading"] and model["manual_offloading"]:
pipe.transformer.to(offload_device)
@ -973,6 +1001,7 @@ NODE_CLASS_MAPPINGS = {
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoContextOptions": CogVideoContextOptions,
"CogVideoSTG": CogVideoSTG,
"CogVideoControlNet": CogVideoControlNet,
"ToraEncodeTrajectory": ToraEncodeTrajectory,
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
@ -991,6 +1020,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoContextOptions": "CogVideo Context Options",
"CogVideoSTG": "CogVideo Spatiotemporal Guidance",
"ToraEncodeTrajectory": "Tora Encode Trajectory",
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
"CogVideoXFasterCache": "CogVideoX FasterCache",

View File

@ -316,6 +316,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def do_spatiotemporal_guidance(self):
return self._stg_scale > 0
@property
def num_timesteps(self):
return self._num_timesteps
@ -324,6 +328,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def interrupt(self):
return self._interrupt
def extract_attn_layers(self):
for key, mod in self.transformer.named_modules():
if "attn1" in key and "to" not in key and "add" not in key and "norm" not in key:
yield mod
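To illustrate which submodules the key filter above keeps, here is a toy module tree; ToyBlock and its attribute names are made up, not the real transformer:

import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = nn.Identity()       # kept: "attn1" in key, no "to"/"add"/"norm"
        self.attn1_to_q = nn.Identity()  # skipped: key contains "to"
        self.norm_attn1 = nn.Identity()  # skipped: key contains "norm"

blocks = nn.Sequential(*(ToyBlock() for _ in range(3)))

matched = [
    key for key, _ in blocks.named_modules()
    if "attn1" in key and "to" not in key and "add" not in key and "norm" not in key
]
print(matched)  # ['0.attn1', '1.attn1', '2.attn1']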
@torch.no_grad()
def __call__(
self,
@ -353,6 +362,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
image_cond_start_percent: float = 0.0,
image_cond_end_percent: float = 1.0,
feta_args: Optional[dict] = None,
stg_mode: Optional[str] = None, # "STG-A",
stg_layers_idx: Optional[List[int]] = [], # [30],
stg_scale: Optional[float] = 0.0, # 4.0
stg_rescaling: Optional[float] = None, # 0.7
):
"""
@ -409,9 +422,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
prompt_embeds,
negative_prompt_embeds,
)
self._stg_scale = stg_scale
self._guidance_scale = guidance_scale
self._interrupt = False
for idx, mod in enumerate(self.extract_attn_layers()):
# set the flag on the layer's attention processor, since CogVideoXAttnProcessor2_0 reads self.stg_mode
mod.processor.stg_mode = stg_mode if idx in stg_layers_idx else None
# 2. Default call parameters
batch_size = prompt_embeds.shape[0]
@ -421,8 +438,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale[0] > 1.0
if do_classifier_free_guidance:
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
prompt_embeds = prompt_embeds.to(self.vae_dtype)
# 4. Prepare timesteps
@ -615,13 +635,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
disable_enhance()
# region context schedule sampling
if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
latent_model_input = torch.cat([latents] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
latent_model_input = torch.cat([latents] * 3)
else:
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
if image_cond_latents is not None:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
latent_image_input = torch.cat([image_cond_latents] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
latent_image_input = torch.cat([image_cond_latents] * 3)
else:
latent_image_input = image_cond_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@ -708,9 +740,19 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
noise_pred = noise_pred.float()
noise_pred /= counter
if do_classifier_free_guidance:
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) \
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
if stg_rescaling:
factor = noise_pred_text.std() / noise_pred.std()
factor = stg_rescaling * factor + (1 - stg_rescaling)
noise_pred = noise_pred * factor
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
@ -733,25 +775,54 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# region sampling
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
latent_model_input = torch.cat([latents] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
latent_model_input = torch.cat([latents] * 3)
else:
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if image_cond_latents is not None:
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
latent_image_input = torch.zeros_like(latent_model_input)
else:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
latent_image_input = torch.cat([image_cond_latents] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
latent_image_input = torch.cat([image_cond_latents] * 3)
else:
latent_image_input = image_cond_latents
if fun_mask is not None: #for fun img2vid and interpolation
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
fun_inpaint_mask = torch.cat([fun_mask] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
fun_inpaint_mask = torch.cat([fun_mask] * 3)
else:
fun_inpaint_mask = fun_mask
masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
else:
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
else: # for Fun inpaint vid2vid
if fun_mask is not None:
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
fun_inpaint_mask = torch.cat([fun_mask] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
fun_inpaint_mask = torch.cat([fun_mask] * 3)
else:
fun_inpaint_mask = fun_mask
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 3)
else:
fun_inpaint_masked_video_latents = fun_masked_video_latents
latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@ -791,11 +862,21 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
self._guidance_scale[i] = 1 + guidance_scale[i] * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
if do_classifier_free_guidance and not self.do_spatiotemporal_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
elif do_classifier_free_guidance and self.do_spatiotemporal_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) \
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
if stg_rescaling:
factor = noise_pred_text.std() / noise_pred.std()
factor = stg_rescaling * factor + (1 - stg_rescaling)
noise_pred = noise_pred * factor
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False)[0]
@ -822,4 +903,4 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# Offload all models
self.maybe_free_model_hooks()
return latents
return latents
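For reference, the combined CFG + STG update used in both sampling branches reduces to the following standalone sketch; the scales and tensor shape are illustrative, not taken from a real run:

import torch

guidance_scale, stg_scale, stg_rescaling = 6.0, 1.0, 0.7

# the pipeline stacks (uncond, text, perturb) predictions along the batch dimension
noise_pred = torch.randn(3, 16, 13, 60, 90)
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)

# classifier-free guidance plus the spatiotemporal guidance term
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_uncond)
    + stg_scale * (noise_pred_text - noise_pred_perturb)
)

# optional rescaling toward the std of the conditional prediction, as in the code above
if stg_rescaling:
    factor = noise_pred_text.std() / noise_pred.std()
    factor = stg_rescaling * factor + (1 - stg_rescaling)
    noise_pred = noise_pred * factor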