Add TeaCache

kijai 2025-01-20 17:06:06 +02:00
parent 3d2ee02d83
commit 76f7930d07
2 changed files with 102 additions and 28 deletions

View File

@@ -84,6 +84,7 @@ def set_attention_func(attention_mode, heads):
            return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
        return func

#for fastercache
def fft(tensor):
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
@@ -101,6 +102,13 @@ def fft(tensor):
    return low_freq_fft, high_freq_fft

#for teacache
def poly1d(coefficients, x):
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()

#region Attention
class CogVideoXAttnProcessor2_0:
    r"""
@@ -526,7 +534,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        self.gradient_checkpointing = False
        self.attention_mode = attention_mode

        #tora
        self.fuser_list = None

        #fastercache
        self.use_fastercache = False
        self.fastercache_counter = 0
        self.fastercache_start_step = 15
@@ -534,7 +547,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        self.fastercache_hf_step = 30
        self.fastercache_device = "cuda"
        self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
        self.attention_mode = attention_mode

        #teacache
        self.use_teacache = False
        self.teacache_rel_l1_thresh = 0.0
        if not self.config.use_rotary_positional_embeddings:
            #CogVideoX-2B
            self.teacache_coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
        else:
            #CogVideoX-5B
            self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]

    def _set_gradient_checkpointing(self, module, value=False):
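
Which coefficient set is chosen is keyed off config.use_rotary_positional_embeddings: CogVideoX-2B uses sinusoidal positional embeddings (flag False) while CogVideoX-5B uses RoPE (flag True), and each variant gets its own polynomial fit. A standalone sketch (illustrative, not part of the commit) of how the model-specific fit rescales the raw relative L1 change before it is compared against teacache_rel_l1_thresh:

import torch

# Coefficient values copied from the commit above.
COEFFS_2B = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
COEFFS_5B = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]

def rescaled_distance(rel_change: torch.Tensor, use_rotary_positional_embeddings: bool) -> torch.Tensor:
    # Pick the fit matching the model variant, then evaluate the polynomial.
    coeffs = COEFFS_5B if use_rotary_positional_embeddings else COEFFS_2B
    result = torch.zeros_like(rel_change)
    for i, c in enumerate(coeffs):
        result += c * rel_change ** (len(coeffs) - 1 - i)
    return result.abs()

# The forward pass accumulates this value step by step and skips the
# transformer blocks while the running sum stays below the threshold.
print(rescaled_distance(torch.tensor(0.05), use_rotary_positional_embeddings=True))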
@@ -662,33 +684,55 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
            output = torch.cat([output, recovered_uncond])
        else:
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter = self.fastercache_counter,
                    fastercache_start_step = self.fastercache_start_step,
                    fastercache_device = self.fastercache_device
                )
                #has_nan = torch.isnan(hidden_states).any()
                #if has_nan:
                #    raise ValueError(f"block output hidden_states has nan: {has_nan}")
            if self.use_teacache:
                if not hasattr(self, 'accumulated_rel_l1_distance'):
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
                else:
                    self.accumulated_rel_l1_distance += poly1d(self.teacache_coefficients, ((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
                    if self.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
                        should_calc = False
                    else:
                        should_calc = True
                        self.accumulated_rel_l1_distance = 0
                    #print("self.accumulated_rel_l1_distance ", self.accumulated_rel_l1_distance)
                self.previous_modulated_input = emb
                if not should_calc:
                    hidden_states += self.previous_residual
                    encoder_hidden_states += self.previous_residual_encoder
            if not self.use_teacache or (self.use_teacache and should_calc):
                if self.use_teacache:
                    ori_hidden_states = hidden_states.clone()
                    ori_encoder_hidden_states = encoder_hidden_states.clone()
                for i, block in enumerate(self.transformer_blocks):
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                        video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                        fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                        block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                        fastercache_counter = self.fastercache_counter,
                        fastercache_start_step = self.fastercache_start_step,
                        fastercache_device = self.fastercache_device
                    )
                #controlnet
                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                        print(controlnet_block_weight)
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
                    #controlnet
                    if (controlnet_states is not None) and (i < len(controlnet_states)):
                        controlnet_states_block = controlnet_states[i]
                        controlnet_block_weight = 1.0
                        if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                            controlnet_block_weight = controlnet_weights[i]
                            print(controlnet_block_weight)
                        elif isinstance(controlnet_weights, (float, int)):
                            controlnet_block_weight = controlnet_weights
                        hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
                if self.use_teacache:
                    self.previous_residual = hidden_states - ori_hidden_states
                    self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
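
Note that the forward pass initializes accumulated_rel_l1_distance lazily via hasattr, so the accumulator, previous_modulated_input, and the cached residuals persist on the transformer object across sampling runs. A hypothetical reset helper (not part of this commit) that one might call before starting a fresh run:

# Hypothetical, not in the commit: drop lazily created TeaCache state so the
# next run starts with should_calc=True instead of inheriting stale residuals.
def reset_teacache_state(transformer) -> None:
    for attr in ("accumulated_rel_l1_distance", "previous_modulated_input",
                 "previous_residual", "previous_residual_encoder"):
        if hasattr(transformer, attr):
            delattr(transformer, attr)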

View File

@@ -584,6 +584,26 @@ class CogVideoXFasterCache:
            "num_blocks_to_cache" : num_blocks_to_cache,
        }
        return (fastercache,)

class CogVideoXTeaCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Cache threshold, higher values are faster while sacrificing quality"}),
            }
        }

    RETURN_TYPES = ("TEACACHEARGS",)
    RETURN_NAMES = ("teacache_args",)
    FUNCTION = "args"
    CATEGORY = "CogVideoWrapper"

    def args(self, rel_l1_thresh):
        teacache = {
            "rel_l1_thresh": rel_l1_thresh
        }
        return (teacache,)

#region Sampler
class CogVideoSampler:
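
The node itself is a thin argument container; called directly outside a ComfyUI graph (assuming the class above is importable), its function simply wraps the threshold in the TEACACHEARGS dict consumed by the sampler:

args = CogVideoXTeaCache().args(rel_l1_thresh=0.3)
print(args)  # ({'rel_l1_thresh': 0.3},)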
@@ -612,6 +632,7 @@ class CogVideoSampler:
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
                "feta_args": ("FETAARGS", ),
                "teacache_args": ("TEACACHEARGS", ),
            }
        }
@@ -621,7 +642,7 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None):
        mm.unload_all_models()
        mm.soft_empty_cache()
@@ -706,6 +727,13 @@ class CogVideoSampler:
            pipe.transformer.use_fastercache = False
            pipe.transformer.fastercache_counter = 0

        if teacache_args is not None:
            pipe.transformer.use_teacache = True
            pipe.transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]
            log.info(f"TeaCache enabled with rel_l1_thresh: {pipe.transformer.teacache_rel_l1_thresh}")
        else:
            pipe.transformer.use_teacache = False

        if not isinstance(cfg, list):
            cfg = [cfg for _ in range(steps)]
        else:
@@ -982,6 +1010,7 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
    "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
    "CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
    "CogVideoXTeaCache": CogVideoXTeaCache,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoSampler": "CogVideo Sampler",
@@ -999,4 +1028,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
    "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
    "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
    "CogVideoXTeaCache": "CogVideoX TeaCache",
}