support Tora for Fun -models

This commit is contained in:
kijai 2024-10-29 10:44:09 +02:00
parent 5fd4f67b14
commit 5b4819ba65
6 changed files with 208 additions and 44 deletions

View File

@ -528,6 +528,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
context_stride: Optional[int] = None, context_stride: Optional[int] = None,
context_overlap: Optional[int] = None, context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True, freenoise: Optional[bool] = True,
tora: Optional[dict] = None,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]: ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -720,7 +721,13 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
if self.transformer.config.use_rotary_positional_embeddings if self.transformer.config.use_rotary_positional_embeddings
else None else None
) )
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
if tora is not None:
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++ # for DPM-solver++
@ -910,6 +917,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
return_dict=False, return_dict=False,
control_latents=current_control_latents, control_latents=current_control_latents,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0] )[0]
noise_pred = noise_pred.float() noise_pred = noise_pred.float()

View File

@ -610,6 +610,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
context_stride: Optional[int] = None, context_stride: Optional[int] = None,
context_overlap: Optional[int] = None, context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True, freenoise: Optional[bool] = True,
tora: Optional[dict] = None,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]: ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -889,6 +890,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
if self.transformer.config.use_rotary_positional_embeddings if self.transformer.config.use_rotary_positional_embeddings
else None else None
) )
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
if tora is not None:
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@ -1061,6 +1069,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]) timestep = t.expand(latent_model_input.shape[0])
current_step_percentage = i / num_inference_steps
# predict noise model_output # predict noise model_output
noise_pred = self.transformer( noise_pred = self.transformer(
hidden_states=latent_model_input, hidden_states=latent_model_input,
@ -1069,6 +1079,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
return_dict=False, return_dict=False,
inpaint_latents=inpaint_latents, inpaint_latents=inpaint_latents,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0] )[0]
noise_pred = noise_pred.float() noise_pred = noise_pred.float()

View File

@ -34,6 +34,7 @@ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from einops import rearrange
try: try:
from sageattention import sageattn from sageattention import sageattn
SAGEATTN_IS_AVAVILABLE = True SAGEATTN_IS_AVAVILABLE = True
@ -42,6 +43,23 @@ except:
logger.info("sageattn not found, using sdpa") logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAVILABLE = False SAGEATTN_IS_AVAVILABLE = False
def fft(tensor):
tensor_fft = torch.fft.fft2(tensor)
tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
B, C, H, W = tensor.size()
radius = min(H, W) // 5
Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
center_x, center_y = W // 2, H // 2
mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
high_freq_mask = ~low_freq_mask
low_freq_fft = tensor_fft_shifted * low_freq_mask
high_freq_fft = tensor_fft_shifted * high_freq_mask
return low_freq_fft, high_freq_fft
class CogVideoXAttnProcessor2_0: class CogVideoXAttnProcessor2_0:
r""" r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@ -315,6 +333,11 @@ class CogVideoXBlock(nn.Module):
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
) -> torch.Tensor: ) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1) text_seq_length = encoder_hidden_states.size(1)
@ -322,7 +345,39 @@ class CogVideoXBlock(nn.Module):
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb hidden_states, encoder_hidden_states, temb
) )
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser
#fastercache
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
# attention # attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1( attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states, hidden_states=norm_hidden_states,
@ -497,6 +552,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.fuser_list = None
self.use_fastercache = False
self.fastercache_counter = 0
self.fastercache_start_step = 15
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value self.gradient_checkpointing = value
@ -609,6 +673,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
inpaint_latents: Optional[torch.Tensor] = None, inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None, control_latents: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_features: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
): ):
batch_size, num_frames, channels, height, width = hidden_states.shape batch_size, num_frames, channels, height, width = hidden_states.shape
@ -649,50 +714,101 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, :text_seq_length] encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:] hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks if self.use_fastercache:
for i, block in enumerate(self.transformer_blocks): self.fastercache_counter+=1
if self.training and self.gradient_checkpointing: if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
# 4. Transformer blocks
def create_custom_forward(module): for i, block in enumerate(self.transformer_blocks):
def custom_forward(*inputs): hidden_states, encoder_hidden_states = block(
return module(*inputs) hidden_states=hidden_states[:1],
encoder_hidden_states=encoder_hidden_states[:1],
return custom_forward temb=emb[:1],
image_rotary_emb=image_rotary_emb,
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( fuser = self.fuser_list[i] if self.fuser_list is not None else None,
create_custom_forward(block), fastercache_counter = self.fastercache_counter,
hidden_states, fastercache_device = self.fastercache_device
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
) )
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else: else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb[:1])
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float())
#lf_step = 40
#hf_step = 30
if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step:
self.delta_hf = self.delta_hf * 1.1
new_hf_uc = self.delta_hf + hf_c
new_lf_uc = self.delta_lf + lf_c
combine_uc = new_lf_uc + new_hf_uc
combined_fft = torch.fft.ifftshift(combine_uc)
recovered_uncond = torch.fft.ifft2(combined_fft).real
recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
output = torch.cat([output, recovered_uncond])
else:
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block( hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
temb=emb, temb=emb,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
) )
if not self.config.use_rotary_positional_embeddings: if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B # CogVideoX-2B
hidden_states = self.norm_final(hidden_states) hidden_states = self.norm_final(hidden_states)
else: else:
# CogVideoX-5B # CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states) hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:] hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block # 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond)
lf_uc, hf_uc = fft(uncond)
self.delta_lf = lf_uc - lf_c
self.delta_hf = hf_uc - hf_c
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict: if not return_dict:
return (output,) return (output,)

View File

@ -662,7 +662,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states[:1], encoder_hidden_states=encoder_hidden_states[:1],
temb=emb[:1], temb=emb[:1],
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None, video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None, fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter, fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device fastercache_device = self.fastercache_device

View File

@ -1480,6 +1480,8 @@ class CogVideoXFunSampler:
"opt_empty_latent": ("LATENT",), "opt_empty_latent": ("LATENT",),
"noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}), "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
"context_options": ("COGCONTEXT", ), "context_options": ("COGCONTEXT", ),
"tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS",),
}, },
} }
@ -1489,7 +1491,7 @@ class CogVideoXFunSampler:
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler, def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None): start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None, fastercache=None, tora_trajectory=None):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"] pipe = pipeline["pipe"]
@ -1538,6 +1540,20 @@ class CogVideoXFunSampler:
else: else:
context_frames, context_stride, context_overlap = None, None, None context_frames, context_stride, context_overlap = None, None, None
if tora_trajectory is not None:
pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
if fastercache is not None:
pipe.transformer.use_fastercache = True
pipe.transformer.fastercache_counter = 0
pipe.transformer.fastercache_start_step = fastercache["start_step"]
pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
pipe.transformer.fastercache_device = fastercache["cache_device"]
else:
pipe.transformer.use_fastercache = False
pipe.transformer.fastercache_counter = 0
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed) generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32 autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
@ -1564,7 +1580,8 @@ class CogVideoXFunSampler:
context_frames=context_frames, context_frames=context_frames,
context_stride= context_stride, context_stride= context_stride,
context_overlap= context_overlap, context_overlap= context_overlap,
freenoise=context_options["freenoise"] if context_options is not None else None freenoise=context_options["freenoise"] if context_options is not None else None,
tora=tora_trajectory if tora_trajectory is not None else None,
) )
#if not pipeline["cpu_offloading"]: #if not pipeline["cpu_offloading"]:
# pipe.transformer.to(offload_device) # pipe.transformer.to(offload_device)

View File

@ -287,11 +287,21 @@ class MGF(nn.Module):
gamma_flow = self.flow_gamma_spatial(flow) gamma_flow = self.flow_gamma_spatial(flow)
beta_flow = self.flow_beta_spatial(flow) beta_flow = self.flow_beta_spatial(flow)
_, _, hh, wh = beta_flow.shape _, _, hh, wh = beta_flow.shape
gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T) if gamma_flow.shape[0] == 1: # Check if batch size is 1
gamma_flow = self.flow_gamma_temporal(gamma_flow) gamma_flow = rearrange(gamma_flow, "b c h w -> b c (h w)")
beta_flow = self.flow_beta_temporal(beta_flow) beta_flow = rearrange(beta_flow, "b c h w -> b c (h w)")
gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh) gamma_flow = self.flow_gamma_temporal(gamma_flow)
beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh) beta_flow = self.flow_beta_temporal(beta_flow)
gamma_flow = rearrange(gamma_flow, "b c (h w) -> b c h w", h=hh, w=wh)
beta_flow = rearrange(beta_flow, "b c (h w) -> b c h w", h=hh, w=wh)
else:
gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
gamma_flow = self.flow_gamma_temporal(gamma_flow)
beta_flow = self.flow_beta_temporal(beta_flow)
gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
h = h + self.flow_cond_norm(h) * gamma_flow + beta_flow h = h + self.flow_cond_norm(h) * gamma_flow + beta_flow
return h return h