diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index efbaecc70..cf63030f2 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
+        patches = transformer_options.get("patches", {})
+
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
 
         def qkv_fn_q(x):
@@ -86,8 +88,12 @@ class WanSelfAttention(nn.Module):
             transformer_options=transformer_options,
         )
 
+        if "self_attn" in patches:
+            for p in patches["self_attn"]:
+                x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
+
         x = self.o(x)
-        return x, q, k
+        return x
 
 
 class WanT2VCrossAttention(WanSelfAttention):
@@ -234,7 +240,7 @@ class WanAttentionBlock(nn.Module):
         # assert e[0].dtype == torch.float32
 
         # self-attention
-        y, q, k = self.self_attn(
+        y = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs,
             transformer_options=transformer_options)
@@ -246,7 +252,7 @@ class WanAttentionBlock(nn.Module):
 
         if "cross_attn" in patches:
             for p in patches["cross_attn"]:
-                x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
+                x = p({"x": x, "transformer_options": transformer_options})
 
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py
index fa44382b0..a651618c8 100644
--- a/comfy/ldm/wan/model_multitalk.py
+++ b/comfy/ldm/wan/model_multitalk.py
@@ -391,6 +391,7 @@ class MultiTalkAudioProjModel(torch.nn.Module):
 
         return context_tokens
 
+
 class WanMultiTalkAttentionBlock(torch.nn.Module):
     def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
         super().__init__()
@@ -398,6 +399,20 @@ class WanMultiTalkAttentionBlock(torch.nn.Module):
         self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
 
 
+class MultiTalkGetAttnMapPatch:
+    def __init__(self, ref_target_masks=None):
+        self.ref_target_masks = ref_target_masks
+
+    def __call__(self, kwargs):
+        transformer_options = kwargs.get("transformer_options", {})
+        x = kwargs["x"]
+
+        if self.ref_target_masks is not None:
+            x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
+            transformer_options["x_ref_attn_map"] = x_ref_attn_map
+        return x
+
+
 class MultiTalkCrossAttnPatch:
     def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
         self.model_patch = model_patch
@@ -412,17 +427,16 @@ class MultiTalkCrossAttnPatch:
             return torch.zeros_like(x)
 
         audio_embeds = transformer_options.get("audio_embeds")
+        x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
 
-        x_ref_attn_map = None
-        if self.ref_target_masks is not None:
-            x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
         norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
         x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
             norm_x, audio_embeds.to(x.dtype), shape=transformer_options["grid_sizes"],
             x_ref_attn_map=x_ref_attn_map
         )
-        return x_audio * self.audio_scale
+        x = x + x_audio * self.audio_scale
+        return x
 
     def models(self):
         return [self.model_patch]
 
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index c82ba0e78..409025121 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -1289,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
         return io.NodeOutput(out_latent)
 
-from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features
+from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
 
 class WanInfiniteTalkToVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -1429,7 +1429,9 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 is_extend=previous_frames is not None,
             ))
         # add cross-attention patch
-        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")
+        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "cross_attn")
+        if token_ref_target_masks is not None:
+            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "self_attn")
 
         out_latent = {}
         out_latent["samples"] = latent
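
Usage sketch (not part of the patch): a minimal, self-contained illustration of the handoff between the new "self_attn" patch point and the existing "cross_attn" one. The Toy* classes below are hypothetical stand-ins; only the call signatures ({"x", "q", "k", "transformer_options"} for "self_attn"; {"x", "transformer_options"} for "cross_attn") and the dispatch loops mirror the diff above.

    import torch

    class ToySelfAttnPatch:
        # Runs inside WanSelfAttention after attention, before self.o(x).
        # It may inspect q/k, stash derived state in transformer_options for
        # a later patch point, and must return the hidden states x.
        def __call__(self, kwargs):
            opts = kwargs["transformer_options"]
            opts["x_ref_attn_map"] = kwargs["q"].mean() + kwargs["k"].mean()
            return kwargs["x"]

    class ToyCrossAttnPatch:
        # Runs inside WanAttentionBlock after cross-attention. Since this
        # change, the patch returns the full residual stream (x = p({...})),
        # so it adds its own contribution to x instead of returning a delta.
        def __call__(self, kwargs):
            x = kwargs["x"]
            state = kwargs["transformer_options"].pop("x_ref_attn_map", None)
            if state is None:
                return x
            return x + 0.0 * state  # placeholder contribution

    # Emulate the model's dispatch at the two patch points:
    transformer_options = {"patches": {"self_attn": [ToySelfAttnPatch()],
                                       "cross_attn": [ToyCrossAttnPatch()]}}
    x = torch.randn(1, 4, 8)
    q = k = torch.randn(1, 4, 2, 4)
    for p in transformer_options["patches"]["self_attn"]:
        x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
    for p in transformer_options["patches"]["cross_attn"]:
        x = p({"x": x, "transformer_options": transformer_options})

This mirrors how MultiTalkGetAttnMapPatch now computes x_ref_attn_map from q/k at the self-attention point and MultiTalkCrossAttnPatch pops it at the cross-attention point, which is why the cross-attention patch no longer receives q/k in its kwargs.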