mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn
This commit is contained in:
parent
7237c36fba
commit
fb099a40b2
@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
def qkv_fn_q(x):
|
||||
@ -86,8 +88,12 @@ class WanSelfAttention(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "self_attn" in patches:
|
||||
for p in patches["self_attn"]:
|
||||
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
|
||||
x = self.o(x)
|
||||
return x, q, k
|
||||
return x
|
||||
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
|
||||
@ -234,7 +240,7 @@ class WanAttentionBlock(nn.Module):
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y, q, k = self.self_attn(
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
@ -246,7 +252,7 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
if "cross_attn" in patches:
|
||||
for p in patches["cross_attn"]:
|
||||
x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
x = p({"x": x, "transformer_options": transformer_options})
|
||||
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
|
||||
@ -391,6 +391,7 @@ class MultiTalkAudioProjModel(torch.nn.Module):
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
@ -398,6 +399,20 @@ class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||
|
||||
|
||||
class MultiTalkGetAttnMapPatch:
|
||||
def __init__(self, ref_target_masks=None):
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x = kwargs["x"]
|
||||
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkCrossAttnPatch:
|
||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||
self.model_patch = model_patch
|
||||
@ -412,17 +427,16 @@ class MultiTalkCrossAttnPatch:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
audio_embeds = transformer_options.get("audio_embeds")
|
||||
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||
|
||||
x_ref_attn_map = None
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||
norm_x, audio_embeds.to(x.dtype),
|
||||
shape=transformer_options["grid_sizes"],
|
||||
x_ref_attn_map=x_ref_attn_map
|
||||
)
|
||||
return x_audio * self.audio_scale
|
||||
x = x + x_audio * self.audio_scale
|
||||
return x
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
@ -1289,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features
|
||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
|
||||
class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1429,7 +1429,9 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
is_extend=previous_frames is not None,
|
||||
))
|
||||
# add cross-attention patch
|
||||
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")
|
||||
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "cross_attn")
|
||||
if token_ref_target_masks is not None:
|
||||
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "self_attn")
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user