Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git, synced 2025-12-27 06:10:55 +08:00
Support VACE with TeaCache

parent 5736669288
commit e96a028254
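For context: TeaCache (see the https://github.com/ali-vilab/TeaCache link in the node's docstring further down) speeds up sampling by skipping the transformer blocks on steps where the model's conditioning has barely changed, replaying a cached output residual instead. A minimal sketch of that reuse pattern, with hypothetical modules and driver code (the real skip decision is made by the tea_cache helper added in this diff):

import torch

# Minimal sketch of residual reuse (hypothetical run_blocks and shapes):
# on a computed step, cache the delta the blocks produced; on a skipped
# step, add the cached delta back instead of running the blocks.
def step(x, run_blocks, cache, should_calc):
    if should_calc:
        original_x = x.clone().detach()
        x = run_blocks(x)                            # full, expensive pass
        cache['previous_residual'] = x - original_x  # remember the delta
    else:
        x = x + cache['previous_residual']           # cheap replay
    return x

cache = {'previous_residual': None}
blocks = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
x = torch.randn(2, 8)
x = step(x, blocks, cache, should_calc=True)   # compute and fill the cache
x = step(x, blocks, cache, should_calc=False)  # skipped step reuses the residual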
@@ -990,6 +990,125 @@ def relative_l1_distance(last_tensor, current_tensor):
     relative_l1_distance = l1_distance / norm
     return relative_l1_distance.to(torch.float32)
 
+@torch.compiler.disable()
+def tea_cache(self, x, e0, e, transformer_options):
+    #teacache for cond and uncond separately
+    rel_l1_thresh = transformer_options["rel_l1_thresh"]
+
+    is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
+
+    should_calc = True
+    suffix = "cond" if is_cond else "uncond"
+
+    # Init cache dict if not exists
+    if not hasattr(self, 'teacache_state'):
+        self.teacache_state = {
+            'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+                     'teacache_skipped_steps': 0, 'previous_residual': None},
+            'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+                       'teacache_skipped_steps': 0, 'previous_residual': None}
+        }
+        logging.info("\nTeaCache: Initialized")
+
+    cache = self.teacache_state[suffix]
+
+    if cache['prev_input'] is not None:
+        if transformer_options["coefficients"] == []:
+            temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
+            curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
+        else:
+            rescale_func = np.poly1d(transformer_options["coefficients"])
+            curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
+        try:
+            if curr_acc_dist < rel_l1_thresh:
+                should_calc = False
+                cache['accumulated_rel_l1_distance'] = curr_acc_dist
+            else:
+                should_calc = True
+                cache['accumulated_rel_l1_distance'] = 0
+        except:
+            should_calc = True
+            cache['accumulated_rel_l1_distance'] = 0
+
+    if transformer_options["coefficients"] == []:
+        cache['prev_input'] = e0.clone().detach()
+    else:
+        cache['prev_input'] = e.clone().detach()
+
+    if not should_calc:
+        x += cache['previous_residual'].to(x.device)
+        cache['teacache_skipped_steps'] += 1
+        #print(f"TeaCache: Skipping {suffix} step")
+    return should_calc, cache
+
+def teacache_wanvideo_vace_forward_orig(self, x, t, context, vace_context, vace_strength, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
+    # embeddings
+    x = self.patch_embedding(x.float()).to(x.dtype)
+    grid_sizes = x.shape[2:]
+    x = x.flatten(2).transpose(1, 2)
+
+    # time embeddings
+    e = self.time_embedding(
+        sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+    e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+    # context
+    context = self.text_embedding(context)
+
+    context_img_len = None
+    if clip_fea is not None:
+        if self.img_emb is not None:
+            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+            context = torch.concat([context_clip, context], dim=1)
+        context_img_len = clip_fea.shape[-2]
+
+    orig_shape = list(vace_context.shape)
+    vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
+    c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
+    c = c.flatten(2).transpose(1, 2)
+    c = list(c.split(orig_shape[0], dim=0))
+
+    if not transformer_options:
+        raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
+
+    teacache_enabled = transformer_options.get("teacache_enabled", False)
+    if not teacache_enabled:
+        should_calc = True
+    else:
+        should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
+
+    if should_calc:
+        original_x = x.clone().detach()
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+            ii = self.vace_layers_mapping.get(i, None)
+            if ii is not None:
+                for iii in range(len(c)):
+                    c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=original_x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                    x += c_skip * vace_strength[iii]
+                del c_skip
+
+        if teacache_enabled:
+            cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
+
+    # head
+    x = self.head(x, e)
+
+    # unpatchify
+    x = self.unpatchify(x, grid_sizes)
+    return x
+
 def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
     # embeddings
     x = self.patch_embedding(x.float()).to(x.dtype)
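The VACE-aware forward above follows the plain forward but also patch-embeds vace_context into a list of streams c and, at the layers listed in vace_layers_mapping, adds each VACE block's skip output into x scaled by vace_strength. Since the whole block loop sits under if should_calc:, a skipped TeaCache step replays a residual that already includes the VACE contribution. A toy illustration of the injection pattern, with hypothetical stand-in modules (the real vace_blocks return a (c_skip, c) pair and also take e, freqs, and context):

import torch

dim, n_tokens = 16, 4
vace_layers_mapping = {0: 0, 2: 1}   # transformer layer index -> VACE block index
vace_strength = [0.8]                # one strength per VACE context stream
blocks = [torch.nn.Linear(dim, dim) for _ in range(3)]
vace_blocks = [torch.nn.Linear(dim, dim) for _ in range(2)]

x = torch.randn(1, n_tokens, dim)
original_x = x.clone().detach()      # VACE blocks see the pre-blocks tokens
c = [torch.randn(1, n_tokens, dim)]  # one embedded VACE context stream

for i, block in enumerate(blocks):
    x = block(x)
    ii = vace_layers_mapping.get(i, None)
    if ii is not None:
        for iii in range(len(c)):
            c_skip = vace_blocks[ii](c[iii])     # stand-in for the two-output call
            x = x + c_skip * vace_strength[iii]  # inject scaled skip connection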
@@ -1003,69 +1122,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
 
     # context
     context = self.text_embedding(context)
-    if clip_fea is not None and self.img_emb is not None:
-        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
-        context = torch.concat([context_clip, context], dim=1)
-
-    @torch.compiler.disable()
-    def tea_cache(x, e0, e, kwargs):
-        #teacache for cond and uncond separately
-        rel_l1_thresh = transformer_options["rel_l1_thresh"]
-
-        is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
+    context_img_len = None
+    if clip_fea is not None:
+        if self.img_emb is not None:
+            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+            context = torch.concat([context_clip, context], dim=1)
+        context_img_len = clip_fea.shape[-2]
-
-        should_calc = True
-        suffix = "cond" if is_cond else "uncond"
-
-        # Init cache dict if not exists
-        if not hasattr(self, 'teacache_state'):
-            self.teacache_state = {
-                'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
-                         'teacache_skipped_steps': 0, 'previous_residual': None},
-                'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
-                           'teacache_skipped_steps': 0, 'previous_residual': None}
-            }
-            logging.info("\nTeaCache: Initialized")
-
-        cache = self.teacache_state[suffix]
-
-        if cache['prev_input'] is not None:
-            if transformer_options["coefficients"] == []:
-                temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
-                curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
-            else:
-                rescale_func = np.poly1d(transformer_options["coefficients"])
-                curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
-            try:
-                if curr_acc_dist < rel_l1_thresh:
-                    should_calc = False
-                    cache['accumulated_rel_l1_distance'] = curr_acc_dist
-                else:
-                    should_calc = True
-                    cache['accumulated_rel_l1_distance'] = 0
-            except:
-                should_calc = True
-                cache['accumulated_rel_l1_distance'] = 0
-
-        if transformer_options["coefficients"] == []:
-            cache['prev_input'] = e0.clone().detach()
-        else:
-            cache['prev_input'] = e.clone().detach()
-
-        if not should_calc:
-            x += cache['previous_residual'].to(x.device)
-            cache['teacache_skipped_steps'] += 1
-            #print(f"TeaCache: Skipping {suffix} step")
-        return should_calc, cache
-
 
     if not transformer_options:
         raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
 
     teacache_enabled = transformer_options.get("teacache_enabled", False)
     if not teacache_enabled:
         should_calc = True
     else:
-        should_calc, cache = tea_cache(x, e0, e, kwargs)
+        should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
 
     if should_calc:
         original_x = x.clone().detach()
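This hunk removes the copy of tea_cache that lived as a closure inside teacache_wanvideo_forward_orig (its old signature took an unused kwargs argument and captured transformer_options from the enclosing scope) in favor of the shared module-level helper added above, so the VACE and non-VACE forwards use one implementation. The skip decision that helper makes boils down to the following sketch, with made-up threshold and tensor values; a non-empty coefficients list is the model-specific rescaling polynomial from the node's presets:

import numpy as np
import torch

# Hypothetical time embeddings from two adjacent sampling steps.
prev_e = torch.randn(1, 512)
curr_e = prev_e + 0.02 * torch.randn_like(prev_e)

rel_l1_thresh = 0.2   # skip threshold (node input)
accumulated = 0.0     # cache['accumulated_rel_l1_distance']
coefficients = []     # empty -> raw relative L1; else np.poly1d rescaling

raw = ((curr_e - prev_e).abs().mean() / prev_e.abs().mean()).item()
if coefficients == []:
    accumulated += raw
else:
    accumulated += np.poly1d(coefficients)(raw)

if accumulated < rel_l1_thresh:
    should_calc = False   # change still small: skip this step, keep accumulating
else:
    should_calc = True    # change got large: recompute and reset the accumulator
    accumulated = 0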
@@ -1075,12 +1145,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                     return out
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context)
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
 
         if teacache_enabled:
             cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
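Both call sites in this hunk gain a context_img_len keyword. When CLIP image features are present they are prepended to the text context (the bs x 257 x dim comment above), so the blocks need the image-token count to separate the two groups again. A small shape illustration with hypothetical dimensions:

import torch

clip_fea = torch.randn(1, 257, 64)   # bs x 257 x dim image features
text_ctx = torch.randn(1, 512, 64)   # bs x 512 x dim text embedding

context = torch.concat([clip_fea, text_ctx], dim=1)  # image tokens first
context_img_len = clip_fea.shape[-2]                 # 257

img_part = context[:, :context_img_len]   # recovered image tokens
txt_part = context[:, context_img_len:]   # recovered text tokens
assert img_part.shape[1] == 257 and txt_part.shape[1] == 512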
@@ -1206,9 +1276,10 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
             if start_percent <= current_percent <= end_percent:
                 c["transformer_options"]["teacache_enabled"] = True
 
+            forward_function = teacache_wanvideo_vace_forward_orig if hasattr(diffusion_model, "vace_layers") else teacache_wanvideo_forward_orig
             context = patch.multiple(
                 diffusion_model,
-                forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
+                forward_orig=forward_function.__get__(diffusion_model, diffusion_model.__class__)
             )
 
             with context:
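The last hunk selects the VACE-aware forward whenever the model exposes vace_layers, then binds it onto the live model instance. __get__ turns a plain function into a bound method, and, assuming patch here is unittest.mock.patch (which is what the patch.multiple call pattern suggests), the original forward_orig is restored when the with-block exits. A self-contained sketch of the pattern with a dummy class:

from unittest.mock import patch

class Model:
    def forward_orig(self, x):
        return x

def patched_forward(self, x):
    return 2 * x

model = Model()
# Bind the plain function as a method of this instance, swap it in, and let
# patch.multiple restore the original attribute on exit.
ctx = patch.multiple(model, forward_orig=patched_forward.__get__(model, model.__class__))
with ctx:
    print(model.forward_orig(3))   # 6 while patched
print(model.forward_orig(3))       # 3 after the context exits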