mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Support the HuMo model. (#9903)
This commit is contained in:
parent
e42682b24e
commit
9288c78fc5
@ -41,6 +41,7 @@ class AudioEncoderModel():
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
outputs["audio_samples"] = audio.shape[2]
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module):
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
eps=1e-6,
|
||||
kv_dim=None,
|
||||
operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@ -43,11 +45,13 @@ class WanSelfAttention(nn.Module):
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
if kv_dim is None:
|
||||
kv_dim = dim
|
||||
|
||||
# layers
|
||||
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
@ -402,6 +406,7 @@ class WanModel(torch.nn.Module):
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
in_dim_ref_conv=None,
|
||||
wan_attn_block_class=WanAttentionBlock,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -479,8 +484,8 @@ class WanModel(torch.nn.Module):
|
||||
# blocks
|
||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
@ -1325,3 +1330,247 @@ class WanModel_S2V(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
|
||||
class WanT2VCrossAttentionGather(WanSelfAttention):
|
||||
|
||||
def forward(self, x, context, transformer_options={}, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C] - video tokens
|
||||
context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
|
||||
"""
|
||||
b, n, d = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(context))
|
||||
v = self.v(context)
|
||||
|
||||
# Handle audio temporal structure (16 tokens per frame)
|
||||
k = k.reshape(-1, 16, n, d).transpose(1, 2)
|
||||
v = v.reshape(-1, 16, n, d).transpose(1, 2)
|
||||
|
||||
# Handle video spatial structure
|
||||
q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
|
||||
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class AudioCrossAttentionWrapper(nn.Module):
|
||||
def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
|
||||
self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings)
|
||||
self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x, audio, transformer_options={}):
|
||||
x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class WanAttentionBlockAudio(WanAttentionBlock):
|
||||
|
||||
def __init__(self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
|
||||
self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
freqs,
|
||||
context,
|
||||
context_img_len=257,
|
||||
audio=None,
|
||||
transformer_options={},
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
e(Tensor): Shape [B, 6, C]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if audio is not None:
|
||||
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
class DummyAdapterLayer(nn.Module):
|
||||
def __init__(self, layer):
|
||||
super().__init__()
|
||||
self.layer = layer
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.layer(*args, **kwargs)
|
||||
|
||||
|
||||
class AudioProjModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len=5,
|
||||
blocks=13, # add a new parameter blocks
|
||||
channels=768, # add a new parameter channels
|
||||
intermediate_dim=512,
|
||||
output_dim=1536,
|
||||
context_tokens=16,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.output_dim = output_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
|
||||
self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
|
||||
self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
|
||||
|
||||
self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, audio_embeds):
|
||||
video_length = audio_embeds.shape[1]
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
|
||||
audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
|
||||
|
||||
context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
|
||||
|
||||
context_tokens = self.audio_proj_glob_norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class HumoWanModel(WanModel):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='humo',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
image_model=None,
|
||||
audio_token_num=16,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
freqs=None,
|
||||
audio_embed=None,
|
||||
reference_latent=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
bs, _, time, height, width = x.shape
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
if reference_latent is not None:
|
||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||
ref = ref.flatten(2).transpose(1, 2)
|
||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
|
||||
x = torch.cat([x, ref], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||
del ref, freqs_ref
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
context_img_len = None
|
||||
|
||||
if audio_embed is not None:
|
||||
audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
|
||||
else:
|
||||
audio = None
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@ -1213,6 +1213,23 @@ class WAN21_Camera(WAN21):
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class WAN21_HuMo(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
audio_embed = kwargs.get("audio_embed", None)
|
||||
if audio_embed is not None:
|
||||
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
|
||||
@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "camera_2.2"
|
||||
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "s2v"
|
||||
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "humo"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
|
||||
@ -1073,6 +1073,16 @@ class WAN21_Vace(WAN21_T2V):
|
||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_HuMo(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "humo",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1351,6 +1361,6 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -1015,6 +1015,103 @@ class WanSoundImageToVideoExtend(io.ComfyNode):
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2):
|
||||
zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
|
||||
zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
|
||||
iter_ = 1 + (frame_num - 1) // 4
|
||||
audio_emb_wind = []
|
||||
for lt_i in range(iter_):
|
||||
if lt_i == 0:
|
||||
st = frame0_idx + lt_i - 2
|
||||
ed = frame0_idx + lt_i + 3
|
||||
wind_feat = torch.stack([
|
||||
audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
|
||||
for i in range(st, ed)
|
||||
], dim=0)
|
||||
wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
|
||||
else:
|
||||
st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
|
||||
ed = frame0_idx + 1 + 4 * lt_i + audio_shift
|
||||
wind_feat = torch.stack([
|
||||
audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
|
||||
for i in range(st, ed)
|
||||
], dim=0)
|
||||
audio_emb_wind.append(wind_feat)
|
||||
audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
|
||||
|
||||
return audio_emb_wind, ed - audio_shift
|
||||
|
||||
|
||||
class WanHuMoImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanHuMoImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||
io.Image.Input("ref_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput:
|
||||
latent_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if ref_image is not None:
|
||||
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
|
||||
else:
|
||||
zero_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [zero_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [zero_latent]}, append=True)
|
||||
|
||||
if audio_encoder_output is not None:
|
||||
audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2)
|
||||
audio_len = audio_encoder_output["audio_samples"] // 640
|
||||
audio_emb = audio_emb[:, :audio_len * 2]
|
||||
|
||||
feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
|
||||
feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
|
||||
feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
|
||||
feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
|
||||
feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25)
|
||||
audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
|
||||
audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0)
|
||||
|
||||
# pad for ref latent
|
||||
zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype)
|
||||
audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
|
||||
|
||||
audio_emb = audio_emb.unsqueeze(0)
|
||||
audio_emb_neg = torch.zeros_like(audio_emb)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg})
|
||||
else:
|
||||
zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device())
|
||||
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1075,6 +1172,7 @@ class WanExtension(ComfyExtension):
|
||||
WanPhantomSubjectToVideo,
|
||||
WanSoundImageToVideo,
|
||||
WanSoundImageToVideoExtend,
|
||||
WanHuMoImageToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user