diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 6507700..2550991 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -10,7 +10,6 @@ import torch.utils.data
 from torch import nn
 
 from .dit.joint_model.context_parallel import get_cp_rank_size
-from .utils import Timer
 from tqdm import tqdm
 
 from comfy.utils import ProgressBar, load_torch_file
@@ -122,55 +121,53 @@ class T2VSynthMochiModel:
         fp8_fastmode: bool = False,
     ):
         super().__init__()
-        t = Timer()
         self.device = device
         self.offload_device = offload_device
 
-        with t("construct_dit"):
-
-            model: nn.Module = torch.nn.utils.skip_init(
-                AsymmDiTJoint,
-                depth=48,
-                patch_size=2,
-                num_heads=24,
-                hidden_size_x=3072,
-                hidden_size_y=1536,
-                mlp_ratio_x=4.0,
-                mlp_ratio_y=4.0,
-                in_channels=12,
-                qk_norm=True,
-                qkv_bias=False,
-                out_bias=True,
-                patch_embed_bias=True,
-                timestep_mlp_bias=True,
-                timestep_scale=1000.0,
-                t5_feat_dim=4096,
-                t5_token_length=256,
-                rope_theta=10000.0,
-            )
-        with t("dit_load_checkpoint"):
-            params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
-            print(f"Loading model state_dict from {dit_checkpoint_path}...")
-            dit_sd = load_torch_file(dit_checkpoint_path)
-            if is_accelerate_available:
-                print("Using accelerate to load and assign model weights to device...")
-                for name, param in model.named_parameters():
-                    if not any(keyword in name for keyword in params_to_keep):
-                        set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
-                    else:
-                        set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
-            else:
-                print("Loading state_dict without accelerate...")
-                model.load_state_dict(dit_sd)
-                for name, param in model.named_parameters():
-                    if not any(keyword in name for keyword in params_to_keep):
-                        param.data = param.data.to(weight_dtype)
-                    else:
-                        param.data = param.data.to(torch.bfloat16)
-
-            if fp8_fastmode:
-                from ..fp8_optimization import convert_fp8_linear
-                convert_fp8_linear(model, torch.bfloat16)
+        print("Initializing model...")
+        model: nn.Module = torch.nn.utils.skip_init(
+            AsymmDiTJoint,
+            depth=48,
+            patch_size=2,
+            num_heads=24,
+            hidden_size_x=3072,
+            hidden_size_y=1536,
+            mlp_ratio_x=4.0,
+            mlp_ratio_y=4.0,
+            in_channels=12,
+            qk_norm=True,
+            qkv_bias=False,
+            out_bias=True,
+            patch_embed_bias=True,
+            timestep_mlp_bias=True,
+            timestep_scale=1000.0,
+            t5_feat_dim=4096,
+            t5_token_length=256,
+            rope_theta=10000.0,
+        )
+
+        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
+        print(f"Loading model state_dict from {dit_checkpoint_path}...")
+        dit_sd = load_torch_file(dit_checkpoint_path)
+        if is_accelerate_available:
+            print("Using accelerate to load and assign model weights to device...")
+            for name, param in model.named_parameters():
+                if not any(keyword in name for keyword in params_to_keep):
+                    set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
+                else:
+                    set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
+        else:
+            print("Loading state_dict without accelerate...")
+            model.load_state_dict(dit_sd)
+            for name, param in model.named_parameters():
+                if not any(keyword in name for keyword in params_to_keep):
+                    param.data = param.data.to(weight_dtype)
+                else:
+                    param.data = param.data.to(torch.bfloat16)
+
+        if fp8_fastmode:
+            from ..fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(model, torch.bfloat16)
 
         self.dit = model
         self.dit.eval()
@@ -179,8 +176,6 @@ class T2VSynthMochiModel:
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
         self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
 
-        t.print_stats()
-
     def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
         B = len(prompts)
         assert (
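Below is a minimal, standalone sketch (not part of the patch) of the mixed-precision assignment pattern the rewritten loading path relies on via accelerate's set_module_tensor_to_device: parameters whose names match a "keep" set stay in bfloat16, everything else is cast to the requested weight dtype. Function and parameter names here are illustrative placeholders, not code from the repository.

import torch
from accelerate.utils import set_module_tensor_to_device

def assign_checkpoint_weights(model, state_dict, weight_dtype, device,
                              params_to_keep=("t_embedder", "norm")):
    # Walk every parameter and materialize it from the checkpoint tensor,
    # keeping bfloat16 for "sensitive" modules and casting the rest to weight_dtype.
    for name, _ in model.named_parameters():
        keep = any(k in name for k in params_to_keep)
        dtype = torch.bfloat16 if keep else weight_dtype
        set_module_tensor_to_device(model, name, device=device,
                                    dtype=dtype, value=state_dict[name])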