Update t2v_synth_mochi.py
This commit is contained in:
parent
bd954ec132
commit
83097a6b63
@ -125,26 +125,26 @@ class T2VSynthMochiModel:
|
|||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
|
|
||||||
print("Initializing model...")
|
print("Initializing model...")
|
||||||
model: nn.Module = torch.nn.utils.skip_init(
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
AsymmDiTJoint,
|
model: nn.Module = AsymmDiTJoint(
|
||||||
depth=48,
|
depth=48,
|
||||||
patch_size=2,
|
patch_size=2,
|
||||||
num_heads=24,
|
num_heads=24,
|
||||||
hidden_size_x=3072,
|
hidden_size_x=3072,
|
||||||
hidden_size_y=1536,
|
hidden_size_y=1536,
|
||||||
mlp_ratio_x=4.0,
|
mlp_ratio_x=4.0,
|
||||||
mlp_ratio_y=4.0,
|
mlp_ratio_y=4.0,
|
||||||
in_channels=12,
|
in_channels=12,
|
||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
qkv_bias=False,
|
qkv_bias=False,
|
||||||
out_bias=True,
|
out_bias=True,
|
||||||
patch_embed_bias=True,
|
patch_embed_bias=True,
|
||||||
timestep_mlp_bias=True,
|
timestep_mlp_bias=True,
|
||||||
timestep_scale=1000.0,
|
timestep_scale=1000.0,
|
||||||
t5_feat_dim=4096,
|
t5_feat_dim=4096,
|
||||||
t5_token_length=256,
|
t5_token_length=256,
|
||||||
rope_theta=10000.0,
|
rope_theta=10000.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user