mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
Add TeaCache
This commit is contained in:
parent
3d2ee02d83
commit
76f7930d07
@ -84,6 +84,7 @@ def set_attention_func(attention_mode, heads):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
|
||||
return func
|
||||
|
||||
#for fastercache
|
||||
def fft(tensor):
|
||||
tensor_fft = torch.fft.fft2(tensor)
|
||||
tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
|
||||
@ -101,6 +102,13 @@ def fft(tensor):
|
||||
|
||||
return low_freq_fft, high_freq_fft
|
||||
|
||||
#for teacache
|
||||
def poly1d(coefficients, x):
|
||||
result = torch.zeros_like(x)
|
||||
for i, coeff in enumerate(coefficients):
|
||||
result += coeff * (x ** (len(coefficients) - 1 - i))
|
||||
return result.abs()
|
||||
|
||||
#region Attention
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
@ -526,7 +534,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.attention_mode = attention_mode
|
||||
|
||||
#tora
|
||||
self.fuser_list = None
|
||||
|
||||
#fastercache
|
||||
self.use_fastercache = False
|
||||
self.fastercache_counter = 0
|
||||
self.fastercache_start_step = 15
|
||||
@ -534,7 +547,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self.fastercache_hf_step = 30
|
||||
self.fastercache_device = "cuda"
|
||||
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
|
||||
self.attention_mode = attention_mode
|
||||
|
||||
#teacache
|
||||
self.use_teacache = False
|
||||
self.teacache_rel_l1_thresh = 0.0
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
#CogVideoX-2B
|
||||
self.teacache_coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
|
||||
else:
|
||||
#CogVideoX-5B
|
||||
self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
|
||||
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
@ -662,33 +684,55 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
|
||||
output = torch.cat([output, recovered_uncond])
|
||||
else:
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
|
||||
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
|
||||
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
|
||||
fastercache_counter = self.fastercache_counter,
|
||||
fastercache_start_step = self.fastercache_start_step,
|
||||
fastercache_device = self.fastercache_device
|
||||
)
|
||||
#has_nan = torch.isnan(hidden_states).any()
|
||||
#if has_nan:
|
||||
# raise ValueError(f"block output hidden_states has nan: {has_nan}")
|
||||
if self.use_teacache:
|
||||
if not hasattr(self, 'accumulated_rel_l1_distance'):
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
self.accumulated_rel_l1_distance += poly1d(self.teacache_coefficients, ((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
|
||||
if self.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
#print("self.accumulated_rel_l1_distance ", self.accumulated_rel_l1_distance)
|
||||
self.previous_modulated_input = emb
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
encoder_hidden_states += self.previous_residual_encoder
|
||||
|
||||
if not self.use_teacache or (self.use_teacache and should_calc):
|
||||
if self.use_teacache:
|
||||
ori_hidden_states = hidden_states.clone()
|
||||
ori_encoder_hidden_states = encoder_hidden_states.clone()
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
|
||||
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
|
||||
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
|
||||
fastercache_counter = self.fastercache_counter,
|
||||
fastercache_start_step = self.fastercache_start_step,
|
||||
fastercache_device = self.fastercache_device
|
||||
)
|
||||
|
||||
#controlnet
|
||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||
controlnet_states_block = controlnet_states[i]
|
||||
controlnet_block_weight = 1.0
|
||||
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
|
||||
controlnet_block_weight = controlnet_weights[i]
|
||||
print(controlnet_block_weight)
|
||||
elif isinstance(controlnet_weights, (float, int)):
|
||||
controlnet_block_weight = controlnet_weights
|
||||
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
|
||||
#controlnet
|
||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||
controlnet_states_block = controlnet_states[i]
|
||||
controlnet_block_weight = 1.0
|
||||
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
|
||||
controlnet_block_weight = controlnet_weights[i]
|
||||
print(controlnet_block_weight)
|
||||
elif isinstance(controlnet_weights, (float, int)):
|
||||
controlnet_block_weight = controlnet_weights
|
||||
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
|
||||
|
||||
if self.use_teacache:
|
||||
self.previous_residual = hidden_states - ori_hidden_states
|
||||
self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
|
||||
32
nodes.py
32
nodes.py
@ -584,6 +584,26 @@ class CogVideoXFasterCache:
|
||||
"num_blocks_to_cache" : num_blocks_to_cache,
|
||||
}
|
||||
return (fastercache,)
|
||||
|
||||
class CogVideoXTeaCache:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Cache threshold, higher values are faster while sacrificing quality"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TEACACHEARGS",)
|
||||
RETURN_NAMES = ("teacache_args",)
|
||||
FUNCTION = "args"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def args(self, rel_l1_thresh):
|
||||
teacache = {
|
||||
"rel_l1_thresh": rel_l1_thresh
|
||||
}
|
||||
return (teacache,)
|
||||
|
||||
#region Sampler
|
||||
class CogVideoSampler:
|
||||
@ -612,6 +632,7 @@ class CogVideoSampler:
|
||||
"tora_trajectory": ("TORAFEATURES", ),
|
||||
"fastercache": ("FASTERCACHEARGS", ),
|
||||
"feta_args": ("FETAARGS", ),
|
||||
"teacache_args": ("TEACACHEARGS", ),
|
||||
}
|
||||
}
|
||||
|
||||
@ -621,7 +642,7 @@ class CogVideoSampler:
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
|
||||
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
|
||||
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None):
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
@ -706,6 +727,13 @@ class CogVideoSampler:
|
||||
pipe.transformer.use_fastercache = False
|
||||
pipe.transformer.fastercache_counter = 0
|
||||
|
||||
if teacache_args is not None:
|
||||
pipe.transformer.use_teacache = True
|
||||
pipe.transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]
|
||||
log.info(f"TeaCache enabled with rel_l1_thresh: {pipe.transformer.teacache_rel_l1_thresh}")
|
||||
else:
|
||||
pipe.transformer.use_teacache = False
|
||||
|
||||
if not isinstance(cfg, list):
|
||||
cfg = [cfg for _ in range(steps)]
|
||||
else:
|
||||
@ -982,6 +1010,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
|
||||
"CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
|
||||
"CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
|
||||
"CogVideoXTeaCache": CogVideoXTeaCache,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoSampler": "CogVideo Sampler",
|
||||
@ -999,4 +1028,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
|
||||
"CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
|
||||
"CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
|
||||
"CogVideoXTeaCache": "CogVideoX TeaCache",
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user