mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 22:14:34 +08:00
Support hunyuan image distilled model. (#9807)
This commit is contained in:
parent
72212fef66
commit
e01e99d075
@ -41,6 +41,7 @@ class HunyuanVideoParams:
|
|||||||
qkv_bias: bool
|
qkv_bias: bool
|
||||||
guidance_embed: bool
|
guidance_embed: bool
|
||||||
byt5: bool
|
byt5: bool
|
||||||
|
meanflow: bool
|
||||||
|
|
||||||
|
|
||||||
class SelfAttentionRef(nn.Module):
|
class SelfAttentionRef(nn.Module):
|
||||||
@ -256,6 +257,11 @@ class HunyuanVideo(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.byt5_in = None
|
self.byt5_in = None
|
||||||
|
|
||||||
|
if params.meanflow:
|
||||||
|
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
self.time_r_in = None
|
||||||
|
|
||||||
if final_layer:
|
if final_layer:
|
||||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
@ -282,6 +288,14 @@ class HunyuanVideo(nn.Module):
|
|||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||||
|
|
||||||
|
if self.time_r_in is not None:
|
||||||
|
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
|
||||||
|
if len(w) > 0:
|
||||||
|
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||||
|
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||||
|
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||||
|
vec = (vec + vec_r) / 2
|
||||||
|
|
||||||
if ref_latent is not None:
|
if ref_latent is not None:
|
||||||
ref_latent_ids = self.img_ids(ref_latent)
|
ref_latent_ids = self.img_ids(ref_latent)
|
||||||
ref_latent = self.img_in(ref_latent)
|
ref_latent = self.img_in(ref_latent)
|
||||||
|
|||||||
@ -142,12 +142,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
|
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
|
||||||
dit_config["patch_size"] = list(in_w.shape[2:])
|
dit_config["patch_size"] = list(in_w.shape[2:])
|
||||||
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
|
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
|
||||||
if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
|
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
|
||||||
dit_config["vec_in_dim"] = 768
|
dit_config["vec_in_dim"] = 768
|
||||||
dit_config["axes_dim"] = [16, 56, 56]
|
|
||||||
else:
|
else:
|
||||||
dit_config["vec_in_dim"] = None
|
dit_config["vec_in_dim"] = None
|
||||||
|
|
||||||
|
if len(dit_config["patch_size"]) == 2:
|
||||||
dit_config["axes_dim"] = [64, 64]
|
dit_config["axes_dim"] = [64, 64]
|
||||||
|
else:
|
||||||
|
dit_config["axes_dim"] = [16, 56, 56]
|
||||||
|
|
||||||
|
if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
|
||||||
|
dit_config["meanflow"] = True
|
||||||
|
else:
|
||||||
|
dit_config["meanflow"] = False
|
||||||
|
|
||||||
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
|
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
|
||||||
dit_config["hidden_size"] = in_w.shape[0]
|
dit_config["hidden_size"] = in_w.shape[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user