mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-08 19:31:20 +08:00
Support VACE with TeaCache
This commit is contained in:
parent
5736669288
commit
e96a028254
@ -990,6 +990,125 @@ def relative_l1_distance(last_tensor, current_tensor):
|
|||||||
relative_l1_distance = l1_distance / norm
|
relative_l1_distance = l1_distance / norm
|
||||||
return relative_l1_distance.to(torch.float32)
|
return relative_l1_distance.to(torch.float32)
|
||||||
|
|
||||||
|
@torch.compiler.disable()
|
||||||
|
def tea_cache(self, x, e0, e, transformer_options):
|
||||||
|
#teacache for cond and uncond separately
|
||||||
|
rel_l1_thresh = transformer_options["rel_l1_thresh"]
|
||||||
|
|
||||||
|
is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
|
||||||
|
|
||||||
|
should_calc = True
|
||||||
|
suffix = "cond" if is_cond else "uncond"
|
||||||
|
|
||||||
|
# Init cache dict if not exists
|
||||||
|
if not hasattr(self, 'teacache_state'):
|
||||||
|
self.teacache_state = {
|
||||||
|
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
||||||
|
'teacache_skipped_steps': 0, 'previous_residual': None},
|
||||||
|
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
||||||
|
'teacache_skipped_steps': 0, 'previous_residual': None}
|
||||||
|
}
|
||||||
|
logging.info("\nTeaCache: Initialized")
|
||||||
|
|
||||||
|
cache = self.teacache_state[suffix]
|
||||||
|
|
||||||
|
if cache['prev_input'] is not None:
|
||||||
|
if transformer_options["coefficients"] == []:
|
||||||
|
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
||||||
|
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
||||||
|
else:
|
||||||
|
rescale_func = np.poly1d(transformer_options["coefficients"])
|
||||||
|
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
|
||||||
|
try:
|
||||||
|
if curr_acc_dist < rel_l1_thresh:
|
||||||
|
should_calc = False
|
||||||
|
cache['accumulated_rel_l1_distance'] = curr_acc_dist
|
||||||
|
else:
|
||||||
|
should_calc = True
|
||||||
|
cache['accumulated_rel_l1_distance'] = 0
|
||||||
|
except:
|
||||||
|
should_calc = True
|
||||||
|
cache['accumulated_rel_l1_distance'] = 0
|
||||||
|
|
||||||
|
if transformer_options["coefficients"] == []:
|
||||||
|
cache['prev_input'] = e0.clone().detach()
|
||||||
|
else:
|
||||||
|
cache['prev_input'] = e.clone().detach()
|
||||||
|
|
||||||
|
if not should_calc:
|
||||||
|
x += cache['previous_residual'].to(x.device)
|
||||||
|
cache['teacache_skipped_steps'] += 1
|
||||||
|
#print(f"TeaCache: Skipping {suffix} step")
|
||||||
|
return should_calc, cache
|
||||||
|
|
||||||
|
def teacache_wanvideo_vace_forward_orig(self, x, t, context, vace_context, vace_strength, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
    """Forward pass for a VACE-enabled model with optional TeaCache skipping.

    Embeds the input, timestep, text context and VACE context, then either
    runs every transformer block (adding the VACE skip outputs at the layers
    listed in ``self.vace_layers_mapping``) or, when ``tea_cache`` reports
    that the step can be skipped, reuses the residual it already applied to
    the hidden states. The head and unpatchify stages always run.

    Raises:
        RuntimeError: if ``transformer_options`` is empty.
    """
    # embeddings
    x = self.patch_embedding(x.float()).to(x.dtype)
    grid_sizes = x.shape[2:]
    x = x.flatten(2).transpose(1, 2)

    # time embeddings
    timestep_emb = sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)
    e = self.time_embedding(timestep_emb)
    e0 = self.time_projection(e).unflatten(1, (6, self.dim))

    # context
    context = self.text_embedding(context)

    context_img_len = None
    if clip_fea is not None:
        if self.img_emb is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)
        context_img_len = clip_fea.shape[-2]

    # VACE context: fold the leading dim into the batch, embed, then split
    # back into one entry per VACE input.
    vace_shape = list(vace_context.shape)
    vace_context = vace_context.movedim(0, 1).reshape([-1] + vace_shape[2:])
    vace_tokens = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
    vace_tokens = vace_tokens.flatten(2).transpose(1, 2)
    c = list(vace_tokens.split(vace_shape[0], dim=0))

    if not transformer_options:
        raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")

    teacache_enabled = transformer_options.get("teacache_enabled", False)
    if teacache_enabled:
        should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
    else:
        should_calc = True

    if should_calc:
        original_x = x.clone().detach()
        blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {})

        for i, block in enumerate(self.blocks):
            patch_key = ("double_block", i)
            if patch_key in blocks_replace:
                def block_wrap(args):
                    return {
                        "img": block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    }

                patched = blocks_replace[patch_key](
                    {"img": x, "txt": context, "vec": e0, "pe": freqs},
                    {"original_block": block_wrap, "transformer_options": transformer_options},
                )
                x = patched["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

            vace_idx = self.vace_layers_mapping.get(i, None)
            if vace_idx is None:
                continue

            # Each VACE entry is advanced by the mapped VACE block and its
            # skip output is added to the hidden states, scaled per entry.
            for hint_idx in range(len(c)):
                c_skip, c[hint_idx] = self.vace_blocks[vace_idx](c[hint_idx], x=original_x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                x += c_skip * vace_strength[hint_idx]
            del c_skip

        if teacache_enabled:
            # Store what the blocks added so a later skipped step can reuse it.
            cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])

    # head
    x = self.head(x, e)

    # unpatchify
    return self.unpatchify(x, grid_sizes)
|
||||||
|
|
||||||
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
|
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
|
||||||
# embeddings
|
# embeddings
|
||||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
@ -1003,69 +1122,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
|||||||
|
|
||||||
# context
|
# context
|
||||||
context = self.text_embedding(context)
|
context = self.text_embedding(context)
|
||||||
if clip_fea is not None and self.img_emb is not None:
|
|
||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
|
||||||
context = torch.concat([context_clip, context], dim=1)
|
|
||||||
|
|
||||||
@torch.compiler.disable()
|
context_img_len = None
|
||||||
def tea_cache(x, e0, e, kwargs):
|
if clip_fea is not None:
|
||||||
#teacache for cond and uncond separately
|
if self.img_emb is not None:
|
||||||
rel_l1_thresh = transformer_options["rel_l1_thresh"]
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
|
|
||||||
|
|
||||||
should_calc = True
|
|
||||||
suffix = "cond" if is_cond else "uncond"
|
|
||||||
|
|
||||||
# Init cache dict if not exists
|
|
||||||
if not hasattr(self, 'teacache_state'):
|
|
||||||
self.teacache_state = {
|
|
||||||
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
|
||||||
'teacache_skipped_steps': 0, 'previous_residual': None},
|
|
||||||
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
|
||||||
'teacache_skipped_steps': 0, 'previous_residual': None}
|
|
||||||
}
|
|
||||||
logging.info("\nTeaCache: Initialized")
|
|
||||||
|
|
||||||
cache = self.teacache_state[suffix]
|
|
||||||
|
|
||||||
if cache['prev_input'] is not None:
|
|
||||||
if transformer_options["coefficients"] == []:
|
|
||||||
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
|
||||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
|
||||||
else:
|
|
||||||
rescale_func = np.poly1d(transformer_options["coefficients"])
|
|
||||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
|
|
||||||
try:
|
|
||||||
if curr_acc_dist < rel_l1_thresh:
|
|
||||||
should_calc = False
|
|
||||||
cache['accumulated_rel_l1_distance'] = curr_acc_dist
|
|
||||||
else:
|
|
||||||
should_calc = True
|
|
||||||
cache['accumulated_rel_l1_distance'] = 0
|
|
||||||
except:
|
|
||||||
should_calc = True
|
|
||||||
cache['accumulated_rel_l1_distance'] = 0
|
|
||||||
|
|
||||||
if transformer_options["coefficients"] == []:
|
|
||||||
cache['prev_input'] = e0.clone().detach()
|
|
||||||
else:
|
|
||||||
cache['prev_input'] = e.clone().detach()
|
|
||||||
|
|
||||||
if not should_calc:
|
|
||||||
x += cache['previous_residual'].to(x.device)
|
|
||||||
cache['teacache_skipped_steps'] += 1
|
|
||||||
#print(f"TeaCache: Skipping {suffix} step")
|
|
||||||
return should_calc, cache
|
|
||||||
|
|
||||||
if not transformer_options:
|
|
||||||
raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
|
|
||||||
|
|
||||||
teacache_enabled = transformer_options.get("teacache_enabled", False)
|
teacache_enabled = transformer_options.get("teacache_enabled", False)
|
||||||
if not teacache_enabled:
|
if not teacache_enabled:
|
||||||
should_calc = True
|
should_calc = True
|
||||||
else:
|
else:
|
||||||
should_calc, cache = tea_cache(x, e0, e, kwargs)
|
should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
|
||||||
|
|
||||||
if should_calc:
|
if should_calc:
|
||||||
original_x = x.clone().detach()
|
original_x = x.clone().detach()
|
||||||
@ -1075,12 +1145,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(x, e=e0, freqs=freqs, context=context)
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
|
||||||
if teacache_enabled:
|
if teacache_enabled:
|
||||||
cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
|
cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
|
||||||
@ -1206,9 +1276,10 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
|
|||||||
if start_percent <= current_percent <= end_percent:
|
if start_percent <= current_percent <= end_percent:
|
||||||
c["transformer_options"]["teacache_enabled"] = True
|
c["transformer_options"]["teacache_enabled"] = True
|
||||||
|
|
||||||
|
forward_function = teacache_wanvideo_vace_forward_orig if hasattr(diffusion_model, "vace_layers") else teacache_wanvideo_forward_orig
|
||||||
context = patch.multiple(
|
context = patch.multiple(
|
||||||
diffusion_model,
|
diffusion_model,
|
||||||
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
forward_orig=forward_function.__get__(diffusion_model, diffusion_model.__class__)
|
||||||
)
|
)
|
||||||
|
|
||||||
with context:
|
with context:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user