Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2026-01-24 10:04:28 +08:00)

Commit c19ad34916 (parent 1016861aec): update to match comfyui latest update
@@ -712,24 +712,7 @@ def relative_l1_distance(last_tensor, current_tensor):
     relative_l1_distance = l1_distance / norm
     return relative_l1_distance.to(torch.float32)
 
-#for now as there doesn't seem to be a way to pass transformer_options to the forward_orig currently
-def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
-    bs, c, t, h, w = x.shape
-    x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
-    patch_size = self.patch_size
-    t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
-    h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
-    w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
-    img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
-    img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
-    img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
-    img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
-
-    freqs = self.rope_embedder(img_ids).movedim(1, 2)
-    return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
-
-def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
+def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
     # embeddings
     x = self.patch_embedding(x.float()).to(x.dtype)
     grid_sizes = x.shape[2:]
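The deleted teacache_wanvideo_forward wrapper existed only because transformer_options previously could not be passed to forward_orig; it duplicated ComfyUI's own padding and RoPE setup so the options could ride along in **kwargs. Now that forward_orig accepts transformer_options directly, the wrapper goes away. The grid arithmetic it duplicated is round-up division of each latent axis by the patch size; a standalone sketch with assumed dimensions (the patch size and latent shape here are illustrative, not taken from this commit):

import math

patch_size = (1, 2, 2)        # assumed Wan (t, h, w) patch size
t, h, w = 21, 60, 104         # hypothetical latent dimensions

# (n + p // 2) // p rounds up for p == 2 and is exact for p == 1,
# matching math.ceil(n / p) for these patch sizes.
t_len = (t + patch_size[0] // 2) // patch_size[0]   # 21
h_len = (h + patch_size[1] // 2) // patch_size[1]   # 30
w_len = (w + patch_size[2] // 2) // patch_size[2]   # 52
assert (t_len, h_len, w_len) == tuple(math.ceil(n / p) for n, p in zip((t, h, w), patch_size))
print(t_len * h_len * w_len)  # 32760 tokens seen by the transformer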
@@ -749,9 +732,9 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
     @torch.compiler.disable()
     def tea_cache(x, e0, e, kwargs):
         #teacache for cond and uncond separately
-        rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
+        rel_l1_thresh = transformer_options["rel_l1_thresh"]
 
-        is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
+        is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
 
         should_calc = True
         suffix = "cond" if is_cond else "uncond"
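The cond and uncond passes get separate cache slots because they run as separate model calls with different inputs; cond_or_uncond == [0] identifies the conditional pass. The rel_l1_thresh comparison below relies on relative_l1_distance, whose tail is visible as context in the first hunk; a sketch of the full function consistent with that context (the lines above the visible tail are reconstructed, not quoted from this commit):

import torch

def relative_l1_distance(last_tensor, current_tensor):
    # Mean absolute change between successive modulation inputs,
    # normalized by the magnitude of the previous input.
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32)

last = torch.ones(4)
print(relative_l1_distance(last, last * 1.1))   # tensor(0.1000)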
@@ -769,11 +752,11 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
         cache = self.teacache_state[suffix]
 
         if cache['prev_input'] is not None:
-            if kwargs["transformer_options"]["coefficients"] == []:
+            if transformer_options["coefficients"] == []:
                 temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
                 curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
             else:
-                rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
+                rescale_func = np.poly1d(transformer_options["coefficients"])
                 curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
             try:
                 if curr_acc_dist < rel_l1_thresh:
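When coefficients are supplied, the raw relative change is first rescaled through a fitted polynomial, following the upstream TeaCache approach; np.poly1d turns a coefficient list (highest power first) into a callable. A minimal illustration with made-up coefficients:

import numpy as np

coefficients = [240.0, -130.0, 24.0, -1.2, 0.023]   # hypothetical, highest power first
rescale_func = np.poly1d(coefficients)

raw = 0.05                     # stand-in for the relative change computed above
print(rescale_func(raw))       # rescaled distance added to the accumulator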
@@ -786,7 +769,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
                 should_calc = True
                 cache['accumulated_rel_l1_distance'] = 0
 
-        if kwargs["transformer_options"]["coefficients"] == []:
+        if transformer_options["coefficients"] == []:
             cache['prev_input'] = e0.clone().detach()
         else:
             cache['prev_input'] = e.clone().detach()
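Stripped of the model plumbing, the decision above is an accumulator with a reset: per-step distances are summed until they cross rel_l1_thresh, at which point the blocks are recomputed and the accumulator cleared. A toy sketch of that control flow over a hypothetical stream of distances:

def should_compute(distances, rel_l1_thresh):
    # True means run the block stack; False means reuse the cached residual.
    accumulated = 0.0
    for step, d in enumerate(distances):
        if step == 0:
            yield True                 # no previous input: always compute
            continue
        accumulated += d
        if accumulated < rel_l1_thresh:
            yield False                # close enough to the last computed step
        else:
            accumulated = 0.0          # drifted too far: recompute and reset
            yield True

print(list(should_compute([0.02, 0.03, 0.02, 0.08, 0.01], rel_l1_thresh=0.08)))
# [True, False, False, True, False]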
@@ -800,16 +783,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
     should_calc, cache = tea_cache(x, e0, e, kwargs)
     if should_calc:
         original_x = x.clone().detach()
-        # arguments
-        block_wargs = dict(
-            e=e0,
-            freqs=freqs,
-            context=context)
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context)
 
-        for block in self.blocks:
-            x = block(x, **block_wargs)
-
-        cache['previous_residual'] = (x - original_x).to(kwargs["transformer_options"]["teacache_device"])
+        cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
 
     # head
     x = self.head(x, e)
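Besides dropping kwargs["transformer_options"], the recomputation path now honors ComfyUI's patches_replace mechanism: a patch registered under ("double_block", i) is called instead of block i, receiving the block inputs as a dict plus the original block wrapped as a callable. A sketch of what such a replace patch looks like from the other side (the dict keys follow the call in the diff; the registration layout and the fake block are assumptions for illustration):

def fake_original_block(args):
    # Stand-in for block_wrap above: returns the transformed "img".
    return {"img": args["img"] + 1}

def my_double_block_patch(args, extra_options):
    # Runs instead of block i; may adjust inputs or outputs around the original.
    out = extra_options["original_block"](args)
    return out

transformer_options = {"patches_replace": {"dit": {("double_block", 0): my_double_block_patch}}}
patch_fn = transformer_options["patches_replace"]["dit"][("double_block", 0)]
print(patch_fn({"img": 1, "txt": None, "vec": None, "pe": None},
               {"original_block": fake_original_block}))   # {'img': 2}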
@@ -932,7 +919,6 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
         context = patch.multiple(
             diffusion_model,
-            forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
             forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
         )
     else:
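Since only forward_orig is overridden now, the forward= entry is simply dropped from the patch. The idiom itself is worth unpacking: patch.multiple swaps attributes on the live model for the duration of a context, and __get__ binds a free function to the instance so it behaves like a method. A self-contained illustration (assuming patch here is unittest.mock.patch, which matches this call shape):

from unittest.mock import patch

class Model:
    def forward_orig(self, x):
        return x

def teacache_version(self, x):
    return x * 2                  # stand-in for the patched forward_orig

m = Model()
# __get__(m, Model) produces a bound method, so `self` is filled in.
with patch.multiple(m, forward_orig=teacache_version.__get__(m, Model)):
    print(m.forward_orig(3))      # 6 while the patch is active
print(m.forward_orig(3))          # 3 again once the context exits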
@@ -961,7 +947,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
 from comfy.ldm.modules.attention import optimized_attention
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.wan.model import WanSelfAttention
 
 def modified_wan_self_attention_forward(self, x, freqs):
     r"""
     Args:
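The import swap tracks upstream ComfyUI: the flux-style apply_rope helper is no longer pulled in, and WanSelfAttention is imported instead, presumably so the modified forward can be bound onto ComfyUI's own attention modules. A hedged sketch of such a rebind (the traversal helper is an assumption, not code from this commit):

def patch_wan_attention(diffusion_model):
    # Hypothetical helper: rebind forward on every WanSelfAttention module,
    # using the same __get__ binding idiom as the patch.multiple call above.
    from comfy.ldm.wan.model import WanSelfAttention
    for name, module in diffusion_model.named_modules():
        if isinstance(module, WanSelfAttention):
            module.forward = modified_wan_self_attention_forward.__get__(module, WanSelfAttention)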