mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00

initial experimental FasterCache support for 2b models

commit e9fc26b5e3 (parent 21025c4742)
@@ -43,6 +43,23 @@ except:
    logger.info("sageattn not found, using sdpa")
    SAGEATTN_IS_AVAILABLE = False

def fft(tensor):
    # 2D FFT over the spatial dims, shifted so low frequencies sit at the center
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    B, C, H, W = tensor.size()
    radius = min(H, W) // 5

    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
    center_x, center_y = W // 2, H // 2
    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2

    # the circular mask selects the low-frequency band; its complement the high band
    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft
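
A minimal round-trip sketch for the fft() helper above (assuming torch is imported and fft() is in scope): the low/high masks partition the shifted spectrum, so their sum inverts back to the input.

    x = torch.randn(2, 16, 60, 90)    # (B, C, H, W) latent-like tensor
    lf, hf = fft(x)
    recovered = torch.fft.ifft2(torch.fft.ifftshift(lf + hf)).real
    assert torch.allclose(recovered, x, atol=1e-4)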

class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on

@@ -192,6 +209,7 @@ class FusedCogVideoXAttnProcessor2_0:

@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in the [CogVideoX](https://github.com/THUDM/CogVideo) model.

@@ -270,7 +288,9 @@ class CogVideoXBlock(nn.Module):
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        self.cached_hidden_states = []
        self.cached_encoder_hidden_states = []

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -279,14 +299,15 @@
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        video_flow_feature: Optional[torch.Tensor] = None,
        fuser=None,
        fastercache_counter=0,
        fastercache_start_step=15,
        fastercache_device="cuda:0",
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # Tora Motion-guidance Fuser
        if video_flow_feature is not None:
            H, W = video_flow_feature.shape[-2:]

@@ -294,14 +315,41 @@
            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
            h = fuser(h, video_flow_feature.to(h), T=T)
            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
            del h, fuser

        # FasterCache: once warmed up, skipped steps reuse the two cached attention
        # outputs via linear extrapolation instead of recomputing attention
        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter % 3 != 0 and self.cached_hidden_states[-1].shape[0] >= norm_hidden_states.shape[0]:
            attn_hidden_states = (
                self.cached_hidden_states[1][:norm_hidden_states.shape[0]] +
                (self.cached_hidden_states[1][:norm_hidden_states.shape[0]] -
                 self.cached_hidden_states[0][:norm_hidden_states.shape[0]])
                * 0.3
            ).to(norm_hidden_states.device, non_blocking=True)
            attn_encoder_hidden_states = (
                self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] +
                (self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] -
                 self.cached_encoder_hidden_states[0][:norm_hidden_states.shape[0]])
                * 0.3
            ).to(norm_hidden_states.device, non_blocking=True)
        else:
            # attention
            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )
            if fastercache_counter == fastercache_start_step:
                self.cached_hidden_states = [
                    attn_hidden_states.to(fastercache_device),
                    attn_hidden_states.to(fastercache_device)
                ]
                self.cached_encoder_hidden_states = [
                    attn_encoder_hidden_states.to(fastercache_device),
                    attn_encoder_hidden_states.to(fastercache_device)
                ]
            elif fastercache_counter > fastercache_start_step:
                self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
                self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
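
The caching scheme above reduces to a small schedule: seed two cache slots at start_step, refresh the newest slot on every computed step, and extrapolate linearly from the two slots on skipped steps. A condensed sketch (the helper name is illustrative; scalars stand in for the attention tensors, and the batch-size guard is omitted):

    def fastercache_attention(counter, start_step, cache, compute_attn):
        # from three steps past start_step on, two of every three steps reuse the cache
        if counter >= start_step + 3 and counter % 3 != 0:
            x0, x1 = cache
            return x1 + (x1 - x0) * 0.3   # linear extrapolation of the cached outputs
        attn = compute_attn()
        if counter == start_step:
            cache[:] = [attn, attn]       # seed both slots
        elif counter > start_step:
            cache[-1] = attn              # refresh only the newest slot
        return attn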

@@ -471,6 +519,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        self.gradient_checkpointing = False

        self.fuser_list = None
        # FasterCache state: caching begins at fastercache_start_step; the cached
        # low-frequency delta keeps being amplified up to fastercache_lf_step,
        # the high-frequency delta from fastercache_hf_step onward
        self.use_fastercache = False
        self.fastercache_counter = 0
        self.fastercache_start_step = 15
        self.fastercache_lf_step = 40
        self.fastercache_hf_step = 30
        self.fastercache_device = "cuda"

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

@@ -606,18 +660,83 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]
        if self.use_fastercache:
            self.fastercache_counter += 1
        if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
            # 3. Transformer blocks (cached path: run the conditional batch only)
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states[:1],
                    encoder_hidden_states=encoder_hidden_states[:1],
                    temb=emb[:1],
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_device=self.fastercache_device
                )

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb[:1])
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            # Note: we use `-1` instead of `channels`:
            # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
            # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
            p = self.config.patch_size
            output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

            # recover the unconditional output from the cached frequency deltas
            (bb, tt, cc, hh, ww) = output.shape
            cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
            lf_c, hf_c = fft(cond.float())
            #lf_step = 40
            #hf_step = 30
            if self.fastercache_counter <= self.fastercache_lf_step:
                self.delta_lf = self.delta_lf * 1.1
            if self.fastercache_counter >= self.fastercache_hf_step:
                self.delta_hf = self.delta_hf * 1.1

            new_hf_uc = self.delta_hf + hf_c
            new_lf_uc = self.delta_lf + lf_c

            combine_uc = new_lf_uc + new_hf_uc
            combined_fft = torch.fft.ifftshift(combine_uc)
            recovered_uncond = torch.fft.ifft2(combined_fft).real
            recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
            output = torch.cat([output, recovered_uncond])
        else:
            # 3. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_device=self.fastercache_device
                )

                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]

@@ -628,28 +747,40 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                        controlnet_block_weight = controlnet_weights

                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb)
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            # Note: we use `-1` instead of `channels`:
            # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
            # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
            p = self.config.patch_size
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

            if self.fastercache_counter >= 16:
                # capture the cond/uncond frequency deltas for reuse on cached steps
                (bb, tt, cc, hh, ww) = output.shape
                cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
                uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)

                lf_c, hf_c = fft(cond)
                lf_uc, hf_uc = fft(uncond)

                self.delta_lf = lf_uc - lf_c
                self.delta_hf = hf_uc - hf_c

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
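
At the transformer level, the counter therefore partitions sampling steps into full and cached ones. A sketch of the condition above, using the default start_step of 15:

    start_step = 15
    for counter in range(1, 31):
        cached = counter >= start_step + 3 and counter % 5 != 0
        print(counter, "cached (cond only, uncond recovered via FFT)" if cached else "full (cond + uncond)")
    # steps 18, 19, 21, 22, 23, 24, 26, ... take the cached path; every fifth
    # step (20, 25, 30) recomputes both batches and refreshes the deltas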

nodes.py (43 changed lines)
@@ -1253,7 +1253,34 @@ class ToraEncodeOpticalFlow:
        return (tora, )


class CogVideoXFasterCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_step": ("INT", {"default": 15, "min": 0, "max": 1024, "step": 1}),
                "hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}),
                "lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}),
                "cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache; main_device is the GPU and uses a lot of VRAM"}),
            },
        }

    RETURN_TYPES = ("FASTERCACHEARGS",)
    RETURN_NAMES = ("fastercache", )
    FUNCTION = "args"
    CATEGORY = "CogVideoWrapper"

    def args(self, start_step, hf_step, lf_step, cache_device):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        fastercache = {
            "start_step": start_step,
            "hf_step": hf_step,
            "lf_step": lf_step,
            "cache_device": device if cache_device == "main_device" else offload_device
        }
        return (fastercache,)
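
The node is a plain argument container: its output dict is simply copied onto the transformer by the sampler (see the CogVideoSampler wiring below). A hypothetical direct call, outside a ComfyUI graph, would look like:

    args = CogVideoXFasterCache().args(start_step=15, hf_step=30, lf_step=40, cache_device="offload_device")[0]
    # -> {"start_step": 15, "hf_step": 30, "lf_step": 40, "cache_device": <offload device>}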

class CogVideoSampler:
    @classmethod
    def INPUT_TYPES(s):

@@ -1280,6 +1307,7 @@ class CogVideoSampler:
                "context_options": ("COGCONTEXT", ),
                "controlnet": ("COGVIDECONTROLNET",),
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
            }
        }

@@ -1289,7 +1317,7 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
        mm.soft_empty_cache()

        base_path = pipeline["base_path"]

@@ -1328,6 +1356,17 @@ class CogVideoSampler:
            padding = torch.zeros((negative.shape[0], target_length - negative.shape[1], negative.shape[2]), device=negative.device)
            negative = torch.cat((negative, padding), dim=1)

        if fastercache is not None:
            pipe.transformer.use_fastercache = True
            pipe.transformer.fastercache_counter = 0
            pipe.transformer.fastercache_start_step = fastercache["start_step"]
            pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
            pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
            pipe.transformer.fastercache_device = fastercache["cache_device"]
        else:
            pipe.transformer.use_fastercache = False
            pipe.transformer.fastercache_counter = 0

        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
        with autocast_context:

@@ -1901,6 +1940,7 @@ NODE_CLASS_MAPPINGS = {
    "ToraEncodeTrajectory": ToraEncodeTrajectory,
    "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
    "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
    "CogVideoXFasterCache": CogVideoXFasterCache
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",

@@ -1924,4 +1964,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "ToraEncodeTrajectory": "Tora Encode Trajectory",
    "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
    "DownloadAndLoadToraModel": "(Down)load Tora Model",
    "CogVideoXFasterCache": "CogVideoX FasterCache"
}