Update t2v_synth_mochi.py

This commit is contained in:
kijai 2024-10-24 00:40:52 +03:00
parent bd954ec132
commit 83097a6b63

View File

@ -125,26 +125,26 @@ class T2VSynthMochiModel:
self.offload_device = offload_device self.offload_device = offload_device
print("Initializing model...") print("Initializing model...")
model: nn.Module = torch.nn.utils.skip_init( with (init_empty_weights() if is_accelerate_available else nullcontext()):
AsymmDiTJoint, model: nn.Module = AsymmDiTJoint(
depth=48, depth=48,
patch_size=2, patch_size=2,
num_heads=24, num_heads=24,
hidden_size_x=3072, hidden_size_x=3072,
hidden_size_y=1536, hidden_size_y=1536,
mlp_ratio_x=4.0, mlp_ratio_x=4.0,
mlp_ratio_y=4.0, mlp_ratio_y=4.0,
in_channels=12, in_channels=12,
qk_norm=True, qk_norm=True,
qkv_bias=False, qkv_bias=False,
out_bias=True, out_bias=True,
patch_embed_bias=True, patch_embed_bias=True,
timestep_mlp_bias=True, timestep_mlp_bias=True,
timestep_scale=1000.0, timestep_scale=1000.0,
t5_feat_dim=4096, t5_feat_dim=4096,
t5_token_length=256, t5_token_length=256,
rope_theta=10000.0, rope_theta=10000.0,
) )
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
print(f"Loading model state_dict from {dit_checkpoint_path}...") print(f"Loading model state_dict from {dit_checkpoint_path}...")