Update t2v_synth_mochi.py
This commit is contained in:
parent
bd954ec132
commit
83097a6b63
@ -125,26 +125,26 @@ class T2VSynthMochiModel:
|
||||
self.offload_device = offload_device
|
||||
|
||||
print("Initializing model...")
|
||||
model: nn.Module = torch.nn.utils.skip_init(
|
||||
AsymmDiTJoint,
|
||||
depth=48,
|
||||
patch_size=2,
|
||||
num_heads=24,
|
||||
hidden_size_x=3072,
|
||||
hidden_size_y=1536,
|
||||
mlp_ratio_x=4.0,
|
||||
mlp_ratio_y=4.0,
|
||||
in_channels=12,
|
||||
qk_norm=True,
|
||||
qkv_bias=False,
|
||||
out_bias=True,
|
||||
patch_embed_bias=True,
|
||||
timestep_mlp_bias=True,
|
||||
timestep_scale=1000.0,
|
||||
t5_feat_dim=4096,
|
||||
t5_token_length=256,
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
model: nn.Module = AsymmDiTJoint(
|
||||
depth=48,
|
||||
patch_size=2,
|
||||
num_heads=24,
|
||||
hidden_size_x=3072,
|
||||
hidden_size_y=1536,
|
||||
mlp_ratio_x=4.0,
|
||||
mlp_ratio_y=4.0,
|
||||
in_channels=12,
|
||||
qk_norm=True,
|
||||
qkv_bias=False,
|
||||
out_bias=True,
|
||||
patch_embed_bias=True,
|
||||
timestep_mlp_bias=True,
|
||||
timestep_scale=1000.0,
|
||||
t5_feat_dim=4096,
|
||||
t5_token_length=256,
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user