Support VACE with TeaCache

This commit is contained in:
kijai 2025-06-09 02:58:41 +03:00
parent 5736669288
commit e96a028254

View File

@ -990,6 +990,125 @@ def relative_l1_distance(last_tensor, current_tensor):
relative_l1_distance = l1_distance / norm
return relative_l1_distance.to(torch.float32)
@torch.compiler.disable()
def tea_cache(self, x, e0, e, transformer_options):
#teacache for cond and uncond separately
rel_l1_thresh = transformer_options["rel_l1_thresh"]
is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
should_calc = True
suffix = "cond" if is_cond else "uncond"
# Init cache dict if not exists
if not hasattr(self, 'teacache_state'):
self.teacache_state = {
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
'teacache_skipped_steps': 0, 'previous_residual': None},
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
'teacache_skipped_steps': 0, 'previous_residual': None}
}
logging.info("\nTeaCache: Initialized")
cache = self.teacache_state[suffix]
if cache['prev_input'] is not None:
if transformer_options["coefficients"] == []:
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
else:
rescale_func = np.poly1d(transformer_options["coefficients"])
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
try:
if curr_acc_dist < rel_l1_thresh:
should_calc = False
cache['accumulated_rel_l1_distance'] = curr_acc_dist
else:
should_calc = True
cache['accumulated_rel_l1_distance'] = 0
except:
should_calc = True
cache['accumulated_rel_l1_distance'] = 0
if transformer_options["coefficients"] == []:
cache['prev_input'] = e0.clone().detach()
else:
cache['prev_input'] = e.clone().detach()
if not should_calc:
x += cache['previous_residual'].to(x.device)
cache['teacache_skipped_steps'] += 1
#print(f"TeaCache: Skipping {suffix} step")
return should_calc, cache
def teacache_wanvideo_vace_forward_orig(self, x, t, context, vace_context, vace_strength, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
orig_shape = list(vace_context.shape)
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
c = c.flatten(2).transpose(1, 2)
c = list(c.split(orig_shape[0], dim=0))
if not transformer_options:
raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
teacache_enabled = transformer_options.get("teacache_enabled", False)
if not teacache_enabled:
should_calc = True
else:
should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
if should_calc:
original_x = x.clone().detach()
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=original_x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength[iii]
del c_skip
if teacache_enabled:
cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
@ -1003,69 +1122,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
@torch.compiler.disable()
def tea_cache(x, e0, e, kwargs):
#teacache for cond and uncond separately
rel_l1_thresh = transformer_options["rel_l1_thresh"]
is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
should_calc = True
suffix = "cond" if is_cond else "uncond"
# Init cache dict if not exists
if not hasattr(self, 'teacache_state'):
self.teacache_state = {
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
'teacache_skipped_steps': 0, 'previous_residual': None},
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
'teacache_skipped_steps': 0, 'previous_residual': None}
}
logging.info("\nTeaCache: Initialized")
cache = self.teacache_state[suffix]
if cache['prev_input'] is not None:
if transformer_options["coefficients"] == []:
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
else:
rescale_func = np.poly1d(transformer_options["coefficients"])
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
try:
if curr_acc_dist < rel_l1_thresh:
should_calc = False
cache['accumulated_rel_l1_distance'] = curr_acc_dist
else:
should_calc = True
cache['accumulated_rel_l1_distance'] = 0
except:
should_calc = True
cache['accumulated_rel_l1_distance'] = 0
if transformer_options["coefficients"] == []:
cache['prev_input'] = e0.clone().detach()
else:
cache['prev_input'] = e.clone().detach()
if not should_calc:
x += cache['previous_residual'].to(x.device)
cache['teacache_skipped_steps'] += 1
#print(f"TeaCache: Skipping {suffix} step")
return should_calc, cache
if not transformer_options:
raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
teacache_enabled = transformer_options.get("teacache_enabled", False)
if not teacache_enabled:
should_calc = True
else:
should_calc, cache = tea_cache(x, e0, e, kwargs)
should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
if should_calc:
original_x = x.clone().detach()
@ -1075,12 +1145,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
if teacache_enabled:
cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
@ -1206,9 +1276,10 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
if start_percent <= current_percent <= end_percent:
c["transformer_options"]["teacache_enabled"] = True
forward_function = teacache_wanvideo_vace_forward_orig if hasattr(diffusion_model, "vace_layers") else teacache_wanvideo_forward_orig
context = patch.multiple(
diffusion_model,
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
forward_orig=forward_function.__get__(diffusion_model, diffusion_model.__class__)
)
with context: