initial experimental FasterCache support for 2b models

kijai 2024-10-28 21:02:10 +02:00
parent 21025c4742
commit e9fc26b5e3
2 changed files with 211 additions and 39 deletions

@@ -43,6 +43,23 @@ except:
    logger.info("sageattn not found, using sdpa")
    SAGEATTN_IS_AVAILABLE = False
def fft(tensor):
    # Split a (B, C, H, W) tensor into the low- and high-frequency components of
    # its centered 2D FFT, using a circular mask of radius min(H, W) // 5.
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    B, C, H, W = tensor.size()
    radius = min(H, W) // 5

    # explicit indexing matches the "ij" default and avoids the deprecation warning
    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    center_x, center_y = W // 2, H // 2
    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2

    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft
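
# Editor's note: a minimal sketch (not part of this commit) of the invariant the
# FasterCache recombination below relies on: the two masks are complementary, so
# inverting the sum of the two halves recovers the input. Name is hypothetical.
def _fft_split_roundtrip(x: torch.Tensor) -> torch.Tensor:
    lf, hf = fft(x)
    return torch.fft.ifft2(torch.fft.ifftshift(lf + hf)).real
# e.g. torch.allclose(_fft_split_roundtrip(x), x, atol=1e-4) holds for float32 x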
class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -192,6 +209,7 @@ class FusedCogVideoXAttnProcessor2_0:

@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in the [CogVideoX](https://github.com/THUDM/CogVideo) model.
@@ -270,7 +288,9 @@ class CogVideoXBlock(nn.Module):
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        self.cached_hidden_states = []
        self.cached_encoder_hidden_states = []
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -279,14 +299,15 @@ class CogVideoXBlock(nn.Module):
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        video_flow_feature: Optional[torch.Tensor] = None,
        fuser=None,
        fastercache_counter=0,
        fastercache_start_step=15,
        fastercache_device="cuda:0",
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # Tora Motion-guidance Fuser
        if video_flow_feature is not None:
            H, W = video_flow_feature.shape[-2:]
@@ -294,14 +315,41 @@ class CogVideoXBlock(nn.Module):
            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
            h = fuser(h, video_flow_feature.to(h), T=T)
            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
            del h, fuser
        # FasterCache: once past the warm-up window, skip attention on two of every
        # three steps and extrapolate from the two cached attention outputs instead.
        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter % 3 != 0 and self.cached_hidden_states[-1].shape[0] >= norm_hidden_states.shape[0]:
            attn_hidden_states = (
                self.cached_hidden_states[1][:norm_hidden_states.shape[0]]
                + (self.cached_hidden_states[1][:norm_hidden_states.shape[0]]
                   - self.cached_hidden_states[0][:norm_hidden_states.shape[0]])
                * 0.3
            ).to(norm_hidden_states.device, non_blocking=True)
            attn_encoder_hidden_states = (
                self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]]
                + (self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]]
                   - self.cached_encoder_hidden_states[0][:norm_hidden_states.shape[0]])
                * 0.3
            ).to(norm_hidden_states.device, non_blocking=True)
        else:
            # attention (full computation)
            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )
        if fastercache_counter == fastercache_start_step:
            # Seed both cache slots with the first full attention output.
            self.cached_hidden_states = [
                attn_hidden_states.to(fastercache_device),
                attn_hidden_states.to(fastercache_device)
            ]
            self.cached_encoder_hidden_states = [
                attn_encoder_hidden_states.to(fastercache_device),
                attn_encoder_hidden_states.to(fastercache_device)
            ]
        elif fastercache_counter > fastercache_start_step:
            # Refresh the newer slot in place with the latest attention output.
            self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
            self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
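
        # Editor's note (sketch, not in the commit): with cache = [prev, last], the
        # reuse rule above is a first-order extrapolation of the attention output:
        #     reused = last + 0.3 * (last - prev)
        # Only the `last` slot is refreshed on full steps; `prev` keeps the
        # start-step output.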
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
@@ -471,6 +519,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        self.gradient_checkpointing = False
        self.fuser_list = None

        self.use_fastercache = False
        self.fastercache_counter = 0
        self.fastercache_start_step = 15
        self.fastercache_lf_step = 40
        self.fastercache_hf_step = 30
        self.fastercache_device = "cuda"
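
        # Editor's note: these defaults mirror the CogVideoXFasterCache node added
        # in nodes.py; a sketch of enabling the cache directly on a loaded model:
        #     transformer.use_fastercache = True
        #     transformer.fastercache_counter = 0
        #     transformer.fastercache_start_step = 15
        #     transformer.fastercache_device = torch.device("cuda")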
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value
@@ -606,18 +660,83 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        if self.use_fastercache:
            self.fastercache_counter += 1
        if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
            # 3. Transformer blocks (run only the conditional half of the batch;
            # the unconditional half is reconstructed below from cached deltas)
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states[:1],
                    encoder_hidden_states=encoder_hidden_states[:1],
                    temb=emb[:1],
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_device=self.fastercache_device
                )

                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb[:1])
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            # Note: we use `-1` instead of `channels`:
            # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
            # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
            p = self.config.patch_size
            output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

            (bb, tt, cc, hh, ww) = output.shape
            cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
            lf_c, hf_c = fft(cond.float())

            # Amplify the cached low/high-frequency deltas within their step windows.
            if self.fastercache_counter <= self.fastercache_lf_step:
                self.delta_lf = self.delta_lf * 1.1
            if self.fastercache_counter >= self.fastercache_hf_step:
                self.delta_hf = self.delta_hf * 1.1

            new_hf_uc = self.delta_hf + hf_c
            new_lf_uc = self.delta_lf + lf_c

            combine_uc = new_lf_uc + new_hf_uc
            combined_fft = torch.fft.ifftshift(combine_uc)
            recovered_uncond = torch.fft.ifft2(combined_fft).real
            recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
            output = torch.cat([output, recovered_uncond])
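
            # Editor's note (sketch, not in the commit): the recombination above is
            #     uncond ≈ ifft2(ifftshift((lf_c + delta_lf) + (hf_c + delta_hf))).real
            # i.e. the skipped unconditional pass is approximated by shifting the
            # conditional output's spectrum by the cached cond→uncond deltas.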
        else:
            # 3. Transformer blocks (full batch)
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_device=self.fastercache_device
                )

                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
@@ -628,28 +747,40 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                    controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb)
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            # Note: we use `-1` instead of `channels`:
            # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
            # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
            p = self.config.patch_size
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
            if self.fastercache_counter >= 16:
                # Capture cond→uncond deltas in frequency space for reuse on skipped
                # steps. (Note: this threshold is hardcoded rather than derived from
                # fastercache_start_step.)
                (bb, tt, cc, hh, ww) = output.shape
                cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
                uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
                lf_c, hf_c = fft(cond)
                lf_uc, hf_uc = fft(uncond)
                self.delta_lf = lf_uc - lf_c
                self.delta_hf = hf_uc - hf_c
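
                # Editor's note (sketch, not in the commit): this is the capture side
                # of the scheme, delta = fft(uncond) - fft(cond) per frequency band;
                # the cached branch above replays it as cond-spectrum + delta.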
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

@@ -1253,7 +1253,34 @@ class ToraEncodeOpticalFlow:
        return (tora, )
class CogVideoXFasterCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_step": ("INT", {"default": 15, "min": 0, "max": 1024, "step": 1}),
                "hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}),
                "lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}),
                "cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache; main_device is the GPU and uses a lot of VRAM"}),
            },
        }

    RETURN_TYPES = ("FASTERCACHEARGS",)
    RETURN_NAMES = ("fastercache", )
    FUNCTION = "args"
    CATEGORY = "CogVideoWrapper"

    def args(self, start_step, hf_step, lf_step, cache_device):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        fastercache = {
            "start_step": start_step,
            "hf_step": hf_step,
            "lf_step": lf_step,
            "cache_device": device if cache_device == "main_device" else offload_device,
        }
        return (fastercache,)
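
# Editor's note: a usage sketch (values are the node defaults; requires the
# ComfyUI runtime for `mm`). The returned dict is what CogVideoSampler below
# copies onto the transformer before sampling:
#     args = CogVideoXFasterCache().args(
#         start_step=15, hf_step=30, lf_step=40, cache_device="main_device"
#     )[0]
#     # -> {"start_step": 15, "hf_step": 30, "lf_step": 40, "cache_device": <torch.device>}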
class CogVideoSampler:
    @classmethod
    def INPUT_TYPES(s):
@@ -1280,6 +1307,7 @@ class CogVideoSampler:
                "context_options": ("COGCONTEXT", ),
                "controlnet": ("COGVIDECONTROLNET",),
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
            }
        }
@@ -1289,7 +1317,7 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
        mm.soft_empty_cache()

        base_path = pipeline["base_path"]
@@ -1328,6 +1356,17 @@ class CogVideoSampler:
            padding = torch.zeros((negative.shape[0], target_length - negative.shape[1], negative.shape[2]), device=negative.device)
            negative = torch.cat((negative, padding), dim=1)

        if fastercache is not None:
            pipe.transformer.use_fastercache = True
            pipe.transformer.fastercache_counter = 0
            pipe.transformer.fastercache_start_step = fastercache["start_step"]
            pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
            pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
            pipe.transformer.fastercache_device = fastercache["cache_device"]
        else:
            pipe.transformer.use_fastercache = False
            pipe.transformer.fastercache_counter = 0
        autocastcondition = not pipeline["onediff"] or dtype != torch.float32
        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
        with autocast_context:
@@ -1901,6 +1940,7 @@ NODE_CLASS_MAPPINGS = {
    "ToraEncodeTrajectory": ToraEncodeTrajectory,
    "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
    "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
    "CogVideoXFasterCache": CogVideoXFasterCache,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1924,4 +1964,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "ToraEncodeTrajectory": "Tora Encode Trajectory",
    "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
    "DownloadAndLoadToraModel": "(Down)load Tora Model",
    "CogVideoXFasterCache": "CogVideoX FasterCache",
}