Add TeaCache
commit 76f7930d07 (parent 3d2ee02d83)
@@ -84,6 +84,7 @@ def set_attention_func(attention_mode, heads):
         return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
     return func
 
+#for fastercache
 def fft(tensor):
     tensor_fft = torch.fft.fft2(tensor)
     tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
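The body of fft() falls between this hunk and the next, so only its opening lines are visible here. For orientation, below is a minimal sketch of the kind of centered low/high frequency split FasterCache uses; the box mask and the low_freq_ratio parameter are illustrative assumptions, not the repository's exact implementation.

import torch

def fft_split_sketch(tensor, low_freq_ratio=0.25):
    # 2D FFT over the trailing (spatial) dims, zero frequency shifted to center
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    h, w = tensor.shape[-2:]
    cy, cx = h // 2, w // 2
    ry, rx = max(1, int(h * low_freq_ratio / 2)), max(1, int(w * low_freq_ratio / 2))
    # the central box of the shifted spectrum holds the low frequencies
    mask = torch.zeros_like(tensor_fft_shifted, dtype=torch.bool)
    mask[..., cy - ry:cy + ry, cx - rx:cx + rx] = True
    low_freq_fft = tensor_fft_shifted * mask
    high_freq_fft = tensor_fft_shifted * ~mask
    return low_freq_fft, high_freq_fft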
@@ -101,6 +102,13 @@ def fft(tensor):
 
     return low_freq_fft, high_freq_fft
 
+#for teacache
+def poly1d(coefficients, x):
+    result = torch.zeros_like(x)
+    for i, coeff in enumerate(coefficients):
+        result += coeff * (x ** (len(coefficients) - 1 - i))
+    return result.abs()
+
 #region Attention
 class CogVideoXAttnProcessor2_0:
     r"""
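The new poly1d helper evaluates a polynomial with the highest-order coefficient first (the same convention as numpy.polyval) and returns the element-wise absolute value. A quick sanity check, using example values of our own:

import torch

def poly1d(coefficients, x):
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()

x = torch.tensor([0.5, 1.0, 2.0])
# |2x^2 - 3x + 1| evaluated element-wise -> tensor([0., 0., 3.])
print(poly1d([2.0, -3.0, 1.0], x))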
@@ -526,7 +534,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         self.gradient_checkpointing = False
 
+        self.attention_mode = attention_mode
+
+        #tora
         self.fuser_list = None
+
+        #fastercache
         self.use_fastercache = False
         self.fastercache_counter = 0
         self.fastercache_start_step = 15
@@ -534,7 +547,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
         self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
-        self.attention_mode = attention_mode
+
+        #teacache
+        self.use_teacache = False
+        self.teacache_rel_l1_thresh = 0.0
+        if not self.config.use_rotary_positional_embeddings:
+            #CogVideoX-2B
+            self.teacache_coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
+        else:
+            #CogVideoX-5B
+            self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
 
 
     def _set_gradient_checkpointing(self, module, value=False):
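The hard-coded coefficient lists are model-specific polynomial fits: they map the raw relative L1 change of the timestep embedding to an estimated relative change of the transformer output, and that rescaled value is what accumulates against teacache_rel_l1_thresh. Evaluating both fits at a sample input change shows the scale; the 0.05 probe value below is our own illustration:

import torch

def poly1d(coefficients, x):
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()

coeffs_2b = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
coeffs_5b = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]

rel_change = torch.tensor(0.05)          # hypothetical per-step relative L1 change
print(poly1d(coeffs_2b, rel_change))     # ~0.072 for CogVideoX-2B
print(poly1d(coeffs_5b, rel_change))     # ~0.106 for CogVideoX-5B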
@@ -662,33 +684,55 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                 output = torch.cat([output, recovered_uncond])
         else:
-            for i, block in enumerate(self.transformer_blocks):
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=emb,
-                    image_rotary_emb=image_rotary_emb,
-                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
-                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
-                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
-                    fastercache_counter = self.fastercache_counter,
-                    fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device
-                )
-                #has_nan = torch.isnan(hidden_states).any()
-                #if has_nan:
-                #    raise ValueError(f"block output hidden_states has nan: {has_nan}")
+            if self.use_teacache:
+                if not hasattr(self, 'accumulated_rel_l1_distance'):
+                    should_calc = True
+                    self.accumulated_rel_l1_distance = 0
+                else:
+                    self.accumulated_rel_l1_distance += poly1d(self.teacache_coefficients, ((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
+                    if self.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
+                        should_calc = False
+                    else:
+                        should_calc = True
+                        self.accumulated_rel_l1_distance = 0
+                    #print("self.accumulated_rel_l1_distance ", self.accumulated_rel_l1_distance)
+                self.previous_modulated_input = emb
+                if not should_calc:
+                    hidden_states += self.previous_residual
+                    encoder_hidden_states += self.previous_residual_encoder
 
+            if not self.use_teacache or (self.use_teacache and should_calc):
+                if self.use_teacache:
+                    ori_hidden_states = hidden_states.clone()
+                    ori_encoder_hidden_states = encoder_hidden_states.clone()
+                for i, block in enumerate(self.transformer_blocks):
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=emb,
+                        image_rotary_emb=image_rotary_emb,
+                        video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                        fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                        block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
+                        fastercache_counter = self.fastercache_counter,
+                        fastercache_start_step = self.fastercache_start_step,
+                        fastercache_device = self.fastercache_device
+                    )
+
                     #controlnet
                     if (controlnet_states is not None) and (i < len(controlnet_states)):
                         controlnet_states_block = controlnet_states[i]
                         controlnet_block_weight = 1.0
                         if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                             controlnet_block_weight = controlnet_weights[i]
                             print(controlnet_block_weight)
                         elif isinstance(controlnet_weights, (float, int)):
                             controlnet_block_weight = controlnet_weights
                         hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
 
+            if self.use_teacache:
+                self.previous_residual = hidden_states - ori_hidden_states
+                self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
+
         if not self.config.use_rotary_positional_embeddings:
             # CogVideoX-2B
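Taken together, the new control flow reduces to the per-step decision sketched below. This is a standalone illustration, not code from the commit: run_blocks stands in for the transformer_blocks loop, TeaCacheState bundles the attributes the commit stores on the transformer, and only the hidden_states residual is tracked (the commit caches an encoder_hidden_states residual as well).

import torch

def poly1d(coefficients, x):
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()

class TeaCacheState:
    def __init__(self, coefficients, rel_l1_thresh):
        self.coefficients = coefficients      # model-specific polynomial fit
        self.rel_l1_thresh = rel_l1_thresh    # skip threshold from the node
        self.accumulated = None               # accumulated_rel_l1_distance
        self.prev_emb = None                  # previous_modulated_input
        self.prev_residual = None             # previous_residual

def teacache_step(state, emb, hidden_states, run_blocks):
    if state.accumulated is None:
        # first step: always run the blocks
        should_calc, state.accumulated = True, torch.tensor(0.0)
    else:
        rel = (emb - state.prev_emb).abs().mean() / state.prev_emb.abs().mean()
        state.accumulated += poly1d(state.coefficients, rel)
        if state.accumulated < state.rel_l1_thresh:
            should_calc = False               # change too small: reuse cached residual
        else:
            should_calc, state.accumulated = True, torch.tensor(0.0)
    state.prev_emb = emb
    if not should_calc:
        return hidden_states + state.prev_residual
    ori = hidden_states.clone()
    hidden_states = run_blocks(hidden_states, emb)
    state.prev_residual = hidden_states - ori  # cache for the steps we skip
    return hidden_states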
nodes.py (32 changed lines)
@@ -584,6 +584,26 @@ class CogVideoXFasterCache:
             "num_blocks_to_cache" : num_blocks_to_cache,
         }
         return (fastercache,)
 
+class CogVideoXTeaCache:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Cache threshold, higher values are faster while sacrificing quality"}),
+            }
+        }
+
+    RETURN_TYPES = ("TEACACHEARGS",)
+    RETURN_NAMES = ("teacache_args",)
+    FUNCTION = "args"
+    CATEGORY = "CogVideoWrapper"
+
+    def args(self, rel_l1_thresh):
+        teacache = {
+            "rel_l1_thresh": rel_l1_thresh
+        }
+        return (teacache,)
+
 #region Sampler
 class CogVideoSampler:
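The node itself only packages the threshold into a dict; the sampler hunks below copy it onto the transformer. A small usage sketch, assuming the CogVideoXTeaCache class from the hunk above is in scope:

# args() returns a one-element tuple holding a plain dict, as ComfyUI expects
node = CogVideoXTeaCache()
(teacache_args,) = node.args(rel_l1_thresh=0.3)
assert teacache_args == {"rel_l1_thresh": 0.3}
# downstream (see the @@ -706,6 +727,13 @@ hunk) this becomes:
#   pipe.transformer.use_teacache = True
#   pipe.transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]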
@@ -612,6 +632,7 @@ class CogVideoSampler:
                 "tora_trajectory": ("TORAFEATURES", ),
                 "fastercache": ("FASTERCACHEARGS", ),
                 "feta_args": ("FETAARGS", ),
+                "teacache_args": ("TEACACHEARGS", ),
             }
         }
 
@@ -621,7 +642,7 @@ class CogVideoSampler:
     CATEGORY = "CogVideoWrapper"
 
     def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
-        denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None):
+        denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None):
         mm.unload_all_models()
         mm.soft_empty_cache()
 
@@ -706,6 +727,13 @@ class CogVideoSampler:
             pipe.transformer.use_fastercache = False
             pipe.transformer.fastercache_counter = 0
 
+        if teacache_args is not None:
+            pipe.transformer.use_teacache = True
+            pipe.transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]
+            log.info(f"TeaCache enabled with rel_l1_thresh: {pipe.transformer.teacache_rel_l1_thresh}")
+        else:
+            pipe.transformer.use_teacache = False
+
         if not isinstance(cfg, list):
             cfg = [cfg for _ in range(steps)]
         else:
@@ -982,6 +1010,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
     "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
     "CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
+    "CogVideoXTeaCache": CogVideoXTeaCache,
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoSampler": "CogVideo Sampler",
@@ -999,4 +1028,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
     "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
     "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
+    "CogVideoXTeaCache": "CogVideoX TeaCache",
     }
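ComfyUI discovers node classes through these two module-level dicts, so registration is the only wiring the new node needs. A minimal lookup sketch, assuming nodes.py has been imported:

# resolve the new node by its registered name and read its default threshold
node_cls = NODE_CLASS_MAPPINGS["CogVideoXTeaCache"]
display_name = NODE_DISPLAY_NAME_MAPPINGS["CogVideoXTeaCache"]   # "CogVideoX TeaCache"
default = node_cls.INPUT_TYPES()["required"]["rel_l1_thresh"][1]["default"]
print(display_name, default)   # CogVideoX TeaCache 0.3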