Update t2v_synth_mochi.py
This commit is contained in:
parent
508eaa22df
commit
bd954ec132
@ -10,7 +10,6 @@ import torch.utils.data
|
||||
from torch import nn
|
||||
|
||||
from .dit.joint_model.context_parallel import get_cp_rank_size
|
||||
from .utils import Timer
|
||||
from tqdm import tqdm
|
||||
from comfy.utils import ProgressBar, load_torch_file
|
||||
|
||||
@ -122,12 +121,10 @@ class T2VSynthMochiModel:
|
||||
fp8_fastmode: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
t = Timer()
|
||||
self.device = device
|
||||
self.offload_device = offload_device
|
||||
|
||||
with t("construct_dit"):
|
||||
|
||||
print("Initializing model...")
|
||||
model: nn.Module = torch.nn.utils.skip_init(
|
||||
AsymmDiTJoint,
|
||||
depth=48,
|
||||
@ -148,7 +145,7 @@ class T2VSynthMochiModel:
|
||||
t5_token_length=256,
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
with t("dit_load_checkpoint"):
|
||||
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||
dit_sd = load_torch_file(dit_checkpoint_path)
|
||||
@ -179,8 +176,6 @@ class T2VSynthMochiModel:
|
||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||
|
||||
t.print_stats()
|
||||
|
||||
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
||||
B = len(prompts)
|
||||
assert (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user