Update t2v_synth_mochi.py

This commit is contained in:
kijai 2024-10-24 00:40:52 +03:00
parent bd954ec132
commit 83097a6b63

View File

@ -125,26 +125,26 @@ class T2VSynthMochiModel:
self.offload_device = offload_device
print("Initializing model...")
model: nn.Module = torch.nn.utils.skip_init(
AsymmDiTJoint,
depth=48,
patch_size=2,
num_heads=24,
hidden_size_x=3072,
hidden_size_y=1536,
mlp_ratio_x=4.0,
mlp_ratio_y=4.0,
in_channels=12,
qk_norm=True,
qkv_bias=False,
out_bias=True,
patch_embed_bias=True,
timestep_mlp_bias=True,
timestep_scale=1000.0,
t5_feat_dim=4096,
t5_token_length=256,
rope_theta=10000.0,
)
with (init_empty_weights() if is_accelerate_available else nullcontext()):
model: nn.Module = AsymmDiTJoint(
depth=48,
patch_size=2,
num_heads=24,
hidden_size_x=3072,
hidden_size_y=1536,
mlp_ratio_x=4.0,
mlp_ratio_y=4.0,
in_channels=12,
qk_norm=True,
qkv_bias=False,
out_bias=True,
patch_embed_bias=True,
timestep_mlp_bias=True,
timestep_scale=1000.0,
t5_feat_dim=4096,
t5_token_length=256,
rope_theta=10000.0,
)
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
print(f"Loading model state_dict from {dit_checkpoint_path}...")