update to match comfyui latest update

This commit is contained in:
kijai 2025-03-14 18:42:31 +02:00
parent 1016861aec
commit c19ad34916

View File

@ -712,24 +712,7 @@ def relative_l1_distance(last_tensor, current_tensor):
relative_l1_distance = l1_distance / norm
return relative_l1_distance.to(torch.float32)
#for now as there doesn't seem to be a way to pass transformer_options to the forward_orig currently
def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
@ -749,9 +732,9 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
@torch.compiler.disable()
def tea_cache(x, e0, e, kwargs):
#teacache for cond and uncond separately
rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
rel_l1_thresh = transformer_options["rel_l1_thresh"]
is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
should_calc = True
suffix = "cond" if is_cond else "uncond"
@ -769,11 +752,11 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
cache = self.teacache_state[suffix]
if cache['prev_input'] is not None:
if kwargs["transformer_options"]["coefficients"] == []:
if transformer_options["coefficients"] == []:
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
else:
rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
rescale_func = np.poly1d(transformer_options["coefficients"])
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
try:
if curr_acc_dist < rel_l1_thresh:
@ -786,7 +769,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
should_calc = True
cache['accumulated_rel_l1_distance'] = 0
if kwargs["transformer_options"]["coefficients"] == []:
if transformer_options["coefficients"] == []:
cache['prev_input'] = e0.clone().detach()
else:
cache['prev_input'] = e.clone().detach()
@ -800,16 +783,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
should_calc, cache = tea_cache(x, e0, e, kwargs)
if should_calc:
original_x = x.clone().detach()
# arguments
block_wargs = dict(
e=e0,
freqs=freqs,
context=context)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
for block in self.blocks:
x = block(x, **block_wargs)
cache['previous_residual'] = (x - original_x).to(kwargs["transformer_options"]["teacache_device"])
cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
# head
x = self.head(x, e)
@ -932,7 +919,6 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
context = patch.multiple(
diffusion_model,
forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
)
else:
@ -961,7 +947,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.wan.model import WanSelfAttention
def modified_wan_self_attention_forward(self, x, freqs):
r"""
Args: