Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn

Author: kijai
Date:   2025-11-05 17:24:13 +02:00
Parent: 7237c36fba
Commit: fb099a40b2

3 changed files with 31 additions and 9 deletions
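In short: q and k only exist inside WanSelfAttention.forward, so the previous design returned them all the way up just so the MultiTalk cross-attention patch could compute its reference attention map. This commit adds a "self_attn" patch hook at the point where q/k are available; that patch stores the derived map in transformer_options, and the cross-attention patch later pops it. A minimal, hypothetical sketch of that producer/consumer flow (apply_patches and StashQK are illustrative names, not part of the commit):

import torch

def apply_patches(name, x, transformer_options, **extra):
    # mirrors the new hook pattern: each registered patch receives a kwargs dict
    # and returns the (possibly updated) hidden states
    for p in transformer_options.get("patches", {}).get(name, []):
        x = p({"x": x, "transformer_options": transformer_options, **extra})
    return x

class StashQK:
    # stand-in for MultiTalkGetAttnMapPatch: derive something from q/k and stash it
    def __call__(self, kwargs):
        opts = kwargs["transformer_options"]
        opts["x_ref_attn_map"] = kwargs["q"].mean(dim=-1)  # placeholder for the real map
        return kwargs["x"]

x = torch.randn(1, 4, 8)
q = k = torch.randn(1, 4, 2, 4)
opts = {"patches": {"self_attn": [StashQK()]}}
x = apply_patches("self_attn", x, opts, q=q, k=k)
assert "x_ref_attn_map" in opts  # ready for a later "cross_attn" patch to pop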


@@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
+        patches = transformer_options.get("patches", {})
+
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

         def qkv_fn_q(x):
@@ -86,8 +88,12 @@ class WanSelfAttention(nn.Module):
             transformer_options=transformer_options,
         )

+        if "self_attn" in patches:
+            for p in patches["self_attn"]:
+                x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
+
         x = self.o(x)
-        return x, q, k
+        return x


 class WanT2VCrossAttention(WanSelfAttention):
@@ -234,7 +240,7 @@ class WanAttentionBlock(nn.Module):
         # assert e[0].dtype == torch.float32

         # self-attention
-        y, q, k = self.self_attn(
+        y = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)
@@ -246,7 +252,7 @@ class WanAttentionBlock(nn.Module):
         if "cross_attn" in patches:
             for p in patches["cross_attn"]:
-                x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
+                x = p({"x": x, "transformer_options": transformer_options})

         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
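Note the contract change for "cross_attn" patches in the block above: the block used to add the patch output itself (x = x + p(...)) and passed q/k along; now it assigns the result directly (x = p(...)), so a patch must add its own residual and can no longer expect q/k in its kwargs. A hedged sketch of adapting such a patch, with compute_delta as a stand-in for whatever the patch actually computes:

import torch

def compute_delta(x):
    return 0.1 * x  # placeholder for the patch's real computation

class OldStyleCrossAttnPatch:
    def __call__(self, kwargs):
        # old contract: return only the delta; the block added it to x
        return compute_delta(kwargs["x"])

class NewStyleCrossAttnPatch:
    def __call__(self, kwargs):
        # new contract: return the fully updated hidden states
        x = kwargs["x"]
        return x + compute_delta(x)

x = torch.randn(2, 16, 8)
assert torch.allclose(x + OldStyleCrossAttnPatch()({"x": x}),
                      NewStyleCrossAttnPatch()({"x": x}))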


@@ -391,6 +391,7 @@ class MultiTalkAudioProjModel(torch.nn.Module):
         return context_tokens


 class WanMultiTalkAttentionBlock(torch.nn.Module):
     def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
         super().__init__()
@@ -398,6 +399,20 @@ class WanMultiTalkAttentionBlock(torch.nn.Module):
         self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)


+class MultiTalkGetAttnMapPatch:
+    def __init__(self, ref_target_masks=None):
+        self.ref_target_masks = ref_target_masks
+
+    def __call__(self, kwargs):
+        transformer_options = kwargs.get("transformer_options", {})
+        x = kwargs["x"]
+
+        if self.ref_target_masks is not None:
+            x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
+            transformer_options["x_ref_attn_map"] = x_ref_attn_map
+        return x
+
+
 class MultiTalkCrossAttnPatch:
     def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
         self.model_patch = model_patch
@@ -412,17 +427,16 @@ class MultiTalkCrossAttnPatch:
             return torch.zeros_like(x)

         audio_embeds = transformer_options.get("audio_embeds")
-        x_ref_attn_map = None
-        if self.ref_target_masks is not None:
-            x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
+        x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)

         norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
         x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
             norm_x, audio_embeds.to(x.dtype),
             shape=transformer_options["grid_sizes"],
             x_ref_attn_map=x_ref_attn_map
         )
-        return x_audio * self.audio_scale
+        x = x + x_audio * self.audio_scale
+        return x

     def models(self):
         return [self.model_patch]
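The two patches now talk through transformer_options: MultiTalkGetAttnMapPatch writes "x_ref_attn_map" during self-attention, and MultiTalkCrossAttnPatch pops it during cross-attention, so the map is consumed once per block and a stale value cannot leak into a block where the self_attn patch did not run. A trivial illustration of that pop() semantics, using a plain dict instead of the real model objects:

opts = {}

# producer side (MultiTalkGetAttnMapPatch): store the map for this block
opts["x_ref_attn_map"] = "map-for-this-block"

# consumer side (MultiTalkCrossAttnPatch): pop() both reads and clears it
first = opts.pop("x_ref_attn_map", None)
second = opts.pop("x_ref_attn_map", None)

assert first == "map-for-this-block"
assert second is None  # a block without the self_attn patch falls back to None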


@@ -1289,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
         return io.NodeOutput(out_latent)

-from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features
+from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features

 class WanInfiniteTalkToVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -1429,7 +1429,9 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
             is_extend=previous_frames is not None,
         ))
         # add cross-attention patch
-        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")
+        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "cross_attn")
+        if token_ref_target_masks is not None:
+            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "self_attn")

         out_latent = {}
         out_latent["samples"] = latent