Add TeaCache

This commit is contained in:
kijai 2025-01-20 17:06:06 +02:00
parent 3d2ee02d83
commit 76f7930d07
2 changed files with 102 additions and 28 deletions

View File

@ -84,6 +84,7 @@ def set_attention_func(attention_mode, heads):
return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32") return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
return func return func
#for fastercache
def fft(tensor): def fft(tensor):
tensor_fft = torch.fft.fft2(tensor) tensor_fft = torch.fft.fft2(tensor)
tensor_fft_shifted = torch.fft.fftshift(tensor_fft) tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
@ -101,6 +102,13 @@ def fft(tensor):
return low_freq_fft, high_freq_fft return low_freq_fft, high_freq_fft
#for teacache
def poly1d(coefficients, x):
result = torch.zeros_like(x)
for i, coeff in enumerate(coefficients):
result += coeff * (x ** (len(coefficients) - 1 - i))
return result.abs()
#region Attention #region Attention
class CogVideoXAttnProcessor2_0: class CogVideoXAttnProcessor2_0:
r""" r"""
@ -526,7 +534,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.attention_mode = attention_mode
#tora
self.fuser_list = None self.fuser_list = None
#fastercache
self.use_fastercache = False self.use_fastercache = False
self.fastercache_counter = 0 self.fastercache_counter = 0
self.fastercache_start_step = 15 self.fastercache_start_step = 15
@ -534,7 +547,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_hf_step = 30 self.fastercache_hf_step = 30
self.fastercache_device = "cuda" self.fastercache_device = "cuda"
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks) self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
self.attention_mode = attention_mode
#teacache
self.use_teacache = False
self.teacache_rel_l1_thresh = 0.0
if not self.config.use_rotary_positional_embeddings:
#CogVideoX-2B
self.teacache_coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
else:
#CogVideoX-5B
self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
@ -662,33 +684,55 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww) recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
output = torch.cat([output, recovered_uncond]) output = torch.cat([output, recovered_uncond])
else: else:
for i, block in enumerate(self.transformer_blocks): if self.use_teacache:
hidden_states, encoder_hidden_states = block( if not hasattr(self, 'accumulated_rel_l1_distance'):
hidden_states=hidden_states, should_calc = True
encoder_hidden_states=encoder_hidden_states, self.accumulated_rel_l1_distance = 0
temb=emb, else:
image_rotary_emb=image_rotary_emb, self.accumulated_rel_l1_distance += poly1d(self.teacache_coefficients, ((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None, if self.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
fuser = self.fuser_list[i] if self.fuser_list is not None else None, should_calc = False
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache, else:
fastercache_counter = self.fastercache_counter, should_calc = True
fastercache_start_step = self.fastercache_start_step, self.accumulated_rel_l1_distance = 0
fastercache_device = self.fastercache_device #print("self.accumulated_rel_l1_distance ", self.accumulated_rel_l1_distance)
) self.previous_modulated_input = emb
#has_nan = torch.isnan(hidden_states).any() if not should_calc:
#if has_nan: hidden_states += self.previous_residual
# raise ValueError(f"block output hidden_states has nan: {has_nan}") encoder_hidden_states += self.previous_residual_encoder
if not self.use_teacache or (self.use_teacache and should_calc):
if self.use_teacache:
ori_hidden_states = hidden_states.clone()
ori_encoder_hidden_states = encoder_hidden_states.clone()
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device
)
#controlnet #controlnet
if (controlnet_states is not None) and (i < len(controlnet_states)): if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i] controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0 controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights): if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i] controlnet_block_weight = controlnet_weights[i]
print(controlnet_block_weight) print(controlnet_block_weight)
elif isinstance(controlnet_weights, (float, int)): elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
if self.use_teacache:
self.previous_residual = hidden_states - ori_hidden_states
self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
if not self.config.use_rotary_positional_embeddings: if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B # CogVideoX-2B

View File

@ -584,6 +584,26 @@ class CogVideoXFasterCache:
"num_blocks_to_cache" : num_blocks_to_cache, "num_blocks_to_cache" : num_blocks_to_cache,
} }
return (fastercache,) return (fastercache,)
class CogVideoXTeaCache:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Cache threshold, higher values are faster while sacrificing quality"}),
}
}
RETURN_TYPES = ("TEACACHEARGS",)
RETURN_NAMES = ("teacache_args",)
FUNCTION = "args"
CATEGORY = "CogVideoWrapper"
def args(self, rel_l1_thresh):
teacache = {
"rel_l1_thresh": rel_l1_thresh
}
return (teacache,)
#region Sampler #region Sampler
class CogVideoSampler: class CogVideoSampler:
@ -612,6 +632,7 @@ class CogVideoSampler:
"tora_trajectory": ("TORAFEATURES", ), "tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS", ), "fastercache": ("FASTERCACHEARGS", ),
"feta_args": ("FETAARGS", ), "feta_args": ("FETAARGS", ),
"teacache_args": ("TEACACHEARGS", ),
} }
} }
@ -621,7 +642,7 @@ class CogVideoSampler:
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None, def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None): denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None):
mm.unload_all_models() mm.unload_all_models()
mm.soft_empty_cache() mm.soft_empty_cache()
@ -706,6 +727,13 @@ class CogVideoSampler:
pipe.transformer.use_fastercache = False pipe.transformer.use_fastercache = False
pipe.transformer.fastercache_counter = 0 pipe.transformer.fastercache_counter = 0
if teacache_args is not None:
pipe.transformer.use_teacache = True
pipe.transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]
log.info(f"TeaCache enabled with rel_l1_thresh: {pipe.transformer.teacache_rel_l1_thresh}")
else:
pipe.transformer.use_teacache = False
if not isinstance(cfg, list): if not isinstance(cfg, list):
cfg = [cfg for _ in range(steps)] cfg = [cfg for _ in range(steps)]
else: else:
@ -982,6 +1010,7 @@ NODE_CLASS_MAPPINGS = {
"CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings, "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
"CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP, "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
"CogVideoEnhanceAVideo": CogVideoEnhanceAVideo, "CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
"CogVideoXTeaCache": CogVideoXTeaCache,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoSampler": "CogVideo Sampler", "CogVideoSampler": "CogVideo Sampler",
@ -999,4 +1028,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings", "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
"CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP", "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
"CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video", "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
"CogVideoXTeaCache": "CogVideoX TeaCache",
} }