support Tora for Fun -models

This commit is contained in:
kijai 2024-10-29 10:44:09 +02:00
parent 5fd4f67b14
commit 5b4819ba65
6 changed files with 208 additions and 44 deletions

View File

@ -528,6 +528,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
tora: Optional[dict] = None,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@ -720,7 +721,13 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
if self.transformer.config.use_rotary_positional_embeddings
else None
)
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
if tora is not None:
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
@ -910,6 +917,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=current_control_latents,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()

View File

@ -610,6 +610,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
tora: Optional[dict] = None,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@ -889,6 +890,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
if self.transformer.config.use_rotary_positional_embeddings
else None
)
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
if tora is not None:
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@ -1061,6 +1069,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
current_step_percentage = i / num_inference_steps
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
@ -1069,6 +1079,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=inpaint_latents,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()

View File

@ -34,6 +34,7 @@ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from einops import rearrange
try:
from sageattention import sageattn
SAGEATTN_IS_AVAVILABLE = True
@ -42,6 +43,23 @@ except:
logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAVILABLE = False
def fft(tensor):
tensor_fft = torch.fft.fft2(tensor)
tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
B, C, H, W = tensor.size()
radius = min(H, W) // 5
Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
center_x, center_y = W // 2, H // 2
mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
high_freq_mask = ~low_freq_mask
low_freq_fft = tensor_fft_shifted * low_freq_mask
high_freq_fft = tensor_fft_shifted * high_freq_mask
return low_freq_fft, high_freq_fft
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@ -315,6 +333,11 @@ class CogVideoXBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@ -322,7 +345,39 @@ class CogVideoXBlock(nn.Module):
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser
#fastercache
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
@ -497,6 +552,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
self.fuser_list = None
self.use_fastercache = False
self.fastercache_counter = 0
self.fastercache_start_step = 15
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@ -609,6 +673,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_features: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
@ -649,50 +714,101 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
if self.use_fastercache:
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states[:1],
encoder_hidden_states=encoder_hidden_states[:1],
temb=emb[:1],
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb[:1])
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float())
#lf_step = 40
#hf_step = 30
if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step:
self.delta_hf = self.delta_hf * 1.1
new_hf_uc = self.delta_hf + hf_c
new_lf_uc = self.delta_lf + lf_c
combine_uc = new_lf_uc + new_hf_uc
combined_fft = torch.fft.ifftshift(combine_uc)
recovered_uncond = torch.fft.ifft2(combined_fft).real
recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
output = torch.cat([output, recovered_uncond])
else:
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond)
lf_uc, hf_uc = fft(uncond)
self.delta_lf = lf_uc - lf_c
self.delta_hf = hf_uc - hf_c
if not return_dict:
return (output,)

View File

@ -662,7 +662,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states[:1],
temb=emb[:1],
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device

View File

@ -1480,6 +1480,8 @@ class CogVideoXFunSampler:
"opt_empty_latent": ("LATENT",),
"noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
"context_options": ("COGCONTEXT", ),
"tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS",),
},
}
@ -1489,7 +1491,7 @@ class CogVideoXFunSampler:
CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None):
start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None, fastercache=None, tora_trajectory=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"]
@ -1538,6 +1540,20 @@ class CogVideoXFunSampler:
else:
context_frames, context_stride, context_overlap = None, None, None
if tora_trajectory is not None:
pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
if fastercache is not None:
pipe.transformer.use_fastercache = True
pipe.transformer.fastercache_counter = 0
pipe.transformer.fastercache_start_step = fastercache["start_step"]
pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
pipe.transformer.fastercache_device = fastercache["cache_device"]
else:
pipe.transformer.use_fastercache = False
pipe.transformer.fastercache_counter = 0
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
@ -1564,7 +1580,8 @@ class CogVideoXFunSampler:
context_frames=context_frames,
context_stride= context_stride,
context_overlap= context_overlap,
freenoise=context_options["freenoise"] if context_options is not None else None
freenoise=context_options["freenoise"] if context_options is not None else None,
tora=tora_trajectory if tora_trajectory is not None else None,
)
#if not pipeline["cpu_offloading"]:
# pipe.transformer.to(offload_device)

View File

@ -287,11 +287,21 @@ class MGF(nn.Module):
gamma_flow = self.flow_gamma_spatial(flow)
beta_flow = self.flow_beta_spatial(flow)
_, _, hh, wh = beta_flow.shape
gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
gamma_flow = self.flow_gamma_temporal(gamma_flow)
beta_flow = self.flow_beta_temporal(beta_flow)
gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
if gamma_flow.shape[0] == 1: # Check if batch size is 1
gamma_flow = rearrange(gamma_flow, "b c h w -> b c (h w)")
beta_flow = rearrange(beta_flow, "b c h w -> b c (h w)")
gamma_flow = self.flow_gamma_temporal(gamma_flow)
beta_flow = self.flow_beta_temporal(beta_flow)
gamma_flow = rearrange(gamma_flow, "b c (h w) -> b c h w", h=hh, w=wh)
beta_flow = rearrange(beta_flow, "b c (h w) -> b c h w", h=hh, w=wh)
else:
gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
gamma_flow = self.flow_gamma_temporal(gamma_flow)
beta_flow = self.flow_beta_temporal(beta_flow)
gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
h = h + self.flow_cond_norm(h) * gamma_flow + beta_flow
return h