Basic WIP support for the wan animate model. (#9939)

comfyanonymous 2025-09-19 00:07:17 -07:00 committed by GitHub
parent 711bcf33ee
commit dc95b6acc0
5 changed files with 666 additions and 1 deletion

View File

@@ -0,0 +1,548 @@
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from .model import WanModel, sinusoidal_embedding_1d
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
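# FaceEncoder compresses the per-frame motion codes produced by the motion encoder further
# below into face tokens for the diffusion backbone: conv1_local expands each code into
# num_heads groups of 1024 channels, conv2/conv3 (stride 2 each) downsample the time axis
# by 4x so the tokens line up with the latent frame count, out_proj maps each group to the
# backbone width, and one learned padding token is appended per frame.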
class FaceEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None, operations=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs)
self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
def forward(self, x):
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
def get_norm_layer(norm_layer, operations=None):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return operations.LayerNorm
elif norm_layer == "rms":
return operations.RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None, device=None, operations=None
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
operations=operations,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed)  # FaceBlock.forward does not take rope freqs
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
operations=None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
# use_context_parallel=False,
) -> torch.Tensor:
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp)
attn = optimized_attention(q, k, v, heads=self.heads_num)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output
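# FaceBlock.forward above attends independently per latent frame: the keys/values are that
# frame's face tokens from FaceEncoder and the queries are the frame's spatial video tokens,
# so every video frame only looks at its own face motion.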
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)]
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81
class FusedLeakyReLU(torch.nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None):
super().__init__()
self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
class Blur(torch.nn.Module):
def __init__(self, kernel, pad, dtype=None, device=None):
super().__init__()
kernel = torch.tensor(kernel, dtype=dtype, device=device)
kernel = kernel[None, :] * kernel[:, None]
kernel = kernel / kernel.sum()
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad)
#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590
class ScaledLeakyReLU(torch.nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
return F.leaky_relu(input, negative_slope=self.negative_slope)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605
class EqualConv2d(torch.nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype))
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None
def forward(self, input):
if self.bias is None:
bias = None
else:
bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype)
return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134
class EqualLinear(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype))
self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.bias is None:
bias = None
else:
bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul
if self.activation:
out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale)
return fused_leaky_relu(out, bias)
return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654
class ConvLayer(torch.nn.Sequential):
def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
stride, padding = 2, 0
else:
stride, padding = 1, kernel_size // 2
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations))
if activate:
layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))
super().__init__(*layers)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704
class ResBlock(torch.nn.Module):
def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations)
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations)
def forward(self, input):
out = self.conv2(self.conv1(input))
skip = self.skip(input)
return (out + skip) / math.sqrt(2)
class EncoderApp(torch.nn.Module):
def __init__(self, w_dim=512, dtype=None, device=None, operations=None):
super().__init__()
kwargs = {"device": device, "dtype": dtype, "operations": operations}
self.convs = torch.nn.ModuleList([
ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs),
ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs),
ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs),
ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs),
EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs)
])
def forward(self, x):
h = x
for conv in self.convs:
h = conv(h)
return h.squeeze(-1).squeeze(-1)
class Encoder(torch.nn.Module):
def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None):
super().__init__()
self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations)
self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)])
def encode_motion(self, x):
return self.fc(self.net_app(x))
class Direction(torch.nn.Module):
def __init__(self, motion_dim, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype))
self.motion_dim = motion_dim
def forward(self, input):
stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype)
Q, _ = torch.linalg.qr(stabilized_weight.float())
if input is None:
return Q
return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1)
class Synthesis(torch.nn.Module):
def __init__(self, motion_dim, dtype=None, device=None, operations=None):
super().__init__()
self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations)
class Generator(torch.nn.Module):
def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None):
super().__init__()
self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations)
self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations)
def get_motion(self, img):
motion_feat = self.enc.encode_motion(img)
return self.dec.direction(motion_feat)
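# The StyleGAN2-derived classes above (Blur through Generator) form the motion encoder:
# a 512x512 face crop is reduced by EncoderApp and the EqualLinear stack to a
# motion_dim(=20) coefficient vector, which Direction expands back to a 512-dim motion
# code as a combination of QR-orthonormalized direction vectors. Only get_motion() is
# used by AnimateWanModel below; the rest of the original generator is not ported.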
class AnimateWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='animate',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
motion_encoder_dim=512,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
self.pose_patch_embedding = operations.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
device=device, dtype=dtype, operations=operations
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
device=device, dtype=dtype, operations=operations
)
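# Pose latents get their own Conv3d patch embedding and are added to every latent frame
# except the first one, which holds the encoded reference image. The face video is pushed
# through the motion encoder in chunks of 8 frames, converted to per-latent-frame face
# tokens by FaceEncoder, prepended with a zero row for the reference frame, and then
# padded or trimmed to match x's frame count.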
def after_patch_embedding(self, x, pose_latents, face_pixel_values):
if pose_latents is not None:
pose_latents = self.pose_patch_embedding(pose_latents)
x[:, :, 1:] += pose_latents
if face_pixel_values is None:
return x, None
b, c, T, h, w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
if motion_vec.shape[1] < x.shape[2]:
B, L, H, C = motion_vec.shape
pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec)
motion_vec = torch.cat([motion_vec, pad], dim=1)
else:
motion_vec = motion_vec[:, :x.shape[2]]
return x, motion_vec
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
pose_latents=None,
face_pixel_values=None,
freqs=None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
full_ref = None
if self.ref_conv is not None:
full_ref = kwargs.get("reference_latent", None)
if full_ref is not None:
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
if i % 5 == 0 and motion_vec is not None:
x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec)
# head
x = self.head(x, e)
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
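
A rough shape sketch of the new conditioning inputs as forward_orig receives them (illustrative only; it assumes a single reference image, the default patch_size (1, 2, 2) and the 16-channel Wan latent space, with T_lat = ((length - 1) // 4) + 1):

    # x:                  (B, C_in, 1 + T_lat, H/8, W/8)   frame 0 lines up with the encoded reference image
    # pose_latents:       (B, 16, T_lat, H/8, W/8)         added to every latent frame except frame 0
    # face_pixel_values:  (B, 3, <=length, 512, 512)       face crops scaled to [-1, 1]; compressed 4x in time inside the model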

View File

@@ -39,6 +39,7 @@ import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@@ -1253,6 +1254,23 @@ class WAN21_HuMo(WAN21):
return out
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
face_video_pixels = kwargs.get("face_video_pixels", None)
if face_video_pixels is not None:
out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels)
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)

View File

@@ -404,6 +404,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "s2v"
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@@ -1096,6 +1096,19 @@ class WAN22_S2V(WAN21_T2V):
out = model_base.WAN22_S2V(self, device=device)
return out
class WAN22_Animate(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "animate",
}
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22_Animate(self, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1361,6 +1374,6 @@ class HunyuanImage21Refiner(HunyuanVideo):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -1108,6 +1108,89 @@ class WanHuMoImageToVideo(io.ComfyNode):
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanAnimateToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanAnimateToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("face_video", optional=True),
io.Image.Input("pose_video", optional=True),
io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Image.Input("continue_motion", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
latent_length = ((length - 1) // 4) + 1
latent_width = width // 8
latent_height = height // 8
trim_latent = 0
if reference_image is None:
reference_image = torch.zeros((1, height, width, 3))
image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
trim_latent += concat_latent_image.shape[2]
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if face_video is not None:
face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
face_video = face_video.movedim(0, 1).unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3])
positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent})
negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent})
if continue_motion is None:
image = torch.ones((length, height, width, 3)) * 0.5
else:
continue_motion = continue_motion[-continue_motion_max_frames:]
continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
image[:continue_motion.shape[0]] = continue_motion
concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
if continue_motion is not None:
mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0
mask = torch.cat((mask, mask_refmotion), dim=2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent, trim_latent)
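# Note on the trim_latent output above: it counts the leading latent frames that hold the
# encoded reference image. They are part of the sampled latent but not of the output video,
# so they are meant to be dropped (e.g. samples[:, :, trim_latent:]) before decoding.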
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -1169,6 +1252,7 @@ class WanExtension(ComfyExtension):
WanSoundImageToVideo,
WanSoundImageToVideoExtend,
WanHuMoImageToVideo,
WanAnimateToVideo,
Wan22ImageToVideoLatent,
]
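
For reference, the default node settings work out as follows (a worked example of the latent/mask bookkeeping in WanAnimateToVideo, assuming one reference image and a 5-frame continue_motion input):

    length = 77                                  # requested pixel frames
    latent_length = ((77 - 1) // 4) + 1          # = 20 latent frames
    trim_latent = 1                              # one latent frame for the encoded reference image
    # concat_latent_image: 1 (reference) + 20 (grey / continue-motion) = 21 latent frames
    # mask: zeros over the reference frame, ones over the 20 motion frames, except the first
    #       ((5 - 1) // 4) + 1 = 2 frames, which are zeroed because they come from continue_motion
    # returned latent: zeros of shape [batch_size, 16, 21, height // 8, width // 8]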