mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-14 16:34:36 +08:00
Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn
This commit is contained in:
parent
7237c36fba
commit
fb099a40b2
@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
|
|||||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||||
"""
|
"""
|
||||||
|
patches = transformer_options.get("patches", {})
|
||||||
|
|
||||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||||
|
|
||||||
def qkv_fn_q(x):
|
def qkv_fn_q(x):
|
||||||
@ -86,8 +88,12 @@ class WanSelfAttention(nn.Module):
|
|||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "self_attn" in patches:
|
||||||
|
for p in patches["self_attn"]:
|
||||||
|
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||||
|
|
||||||
x = self.o(x)
|
x = self.o(x)
|
||||||
return x, q, k
|
return x
|
||||||
|
|
||||||
|
|
||||||
class WanT2VCrossAttention(WanSelfAttention):
|
class WanT2VCrossAttention(WanSelfAttention):
|
||||||
@ -234,7 +240,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
# assert e[0].dtype == torch.float32
|
# assert e[0].dtype == torch.float32
|
||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
y, q, k = self.self_attn(
|
y = self.self_attn(
|
||||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
freqs, transformer_options=transformer_options)
|
freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
@ -246,7 +252,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
if "cross_attn" in patches:
|
if "cross_attn" in patches:
|
||||||
for p in patches["cross_attn"]:
|
for p in patches["cross_attn"]:
|
||||||
x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
x = p({"x": x, "transformer_options": transformer_options})
|
||||||
|
|
||||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||||
|
|||||||
@ -391,6 +391,7 @@ class MultiTalkAudioProjModel(torch.nn.Module):
|
|||||||
|
|
||||||
return context_tokens
|
return context_tokens
|
||||||
|
|
||||||
|
|
||||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -398,6 +399,20 @@ class WanMultiTalkAttentionBlock(torch.nn.Module):
|
|||||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTalkGetAttnMapPatch:
|
||||||
|
def __init__(self, ref_target_masks=None):
|
||||||
|
self.ref_target_masks = ref_target_masks
|
||||||
|
|
||||||
|
def __call__(self, kwargs):
|
||||||
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
|
x = kwargs["x"]
|
||||||
|
|
||||||
|
if self.ref_target_masks is not None:
|
||||||
|
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||||
|
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class MultiTalkCrossAttnPatch:
|
class MultiTalkCrossAttnPatch:
|
||||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||||
self.model_patch = model_patch
|
self.model_patch = model_patch
|
||||||
@ -412,17 +427,16 @@ class MultiTalkCrossAttnPatch:
|
|||||||
return torch.zeros_like(x)
|
return torch.zeros_like(x)
|
||||||
|
|
||||||
audio_embeds = transformer_options.get("audio_embeds")
|
audio_embeds = transformer_options.get("audio_embeds")
|
||||||
|
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||||
|
|
||||||
x_ref_attn_map = None
|
|
||||||
if self.ref_target_masks is not None:
|
|
||||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
|
||||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||||
norm_x, audio_embeds.to(x.dtype),
|
norm_x, audio_embeds.to(x.dtype),
|
||||||
shape=transformer_options["grid_sizes"],
|
shape=transformer_options["grid_sizes"],
|
||||||
x_ref_attn_map=x_ref_attn_map
|
x_ref_attn_map=x_ref_attn_map
|
||||||
)
|
)
|
||||||
return x_audio * self.audio_scale
|
x = x + x_audio * self.audio_scale
|
||||||
|
return x
|
||||||
|
|
||||||
def models(self):
|
def models(self):
|
||||||
return [self.model_patch]
|
return [self.model_patch]
|
||||||
|
|||||||
@ -1289,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
|||||||
return io.NodeOutput(out_latent)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
|
|
||||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features
|
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
|
||||||
class WanInfiniteTalkToVideo(io.ComfyNode):
|
class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -1429,7 +1429,9 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
is_extend=previous_frames is not None,
|
is_extend=previous_frames is not None,
|
||||||
))
|
))
|
||||||
# add cross-attention patch
|
# add cross-attention patch
|
||||||
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")
|
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "cross_attn")
|
||||||
|
if token_ref_target_masks is not None:
|
||||||
|
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "self_attn")
|
||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user