Update t2v_synth_mochi.py

This commit is contained in:
kijai 2024-10-24 00:36:38 +03:00
parent 508eaa22df
commit bd954ec132

View File

@ -10,7 +10,6 @@ import torch.utils.data
from torch import nn from torch import nn
from .dit.joint_model.context_parallel import get_cp_rank_size from .dit.joint_model.context_parallel import get_cp_rank_size
from .utils import Timer
from tqdm import tqdm from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file from comfy.utils import ProgressBar, load_torch_file
@ -122,55 +121,53 @@ class T2VSynthMochiModel:
fp8_fastmode: bool = False, fp8_fastmode: bool = False,
): ):
super().__init__() super().__init__()
t = Timer()
self.device = device self.device = device
self.offload_device = offload_device self.offload_device = offload_device
with t("construct_dit"): print("Initializing model...")
model: nn.Module = torch.nn.utils.skip_init(
AsymmDiTJoint,
depth=48,
patch_size=2,
num_heads=24,
hidden_size_x=3072,
hidden_size_y=1536,
mlp_ratio_x=4.0,
mlp_ratio_y=4.0,
in_channels=12,
qk_norm=True,
qkv_bias=False,
out_bias=True,
patch_embed_bias=True,
timestep_mlp_bias=True,
timestep_scale=1000.0,
t5_feat_dim=4096,
t5_token_length=256,
rope_theta=10000.0,
)
model: nn.Module = torch.nn.utils.skip_init( params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
AsymmDiTJoint, print(f"Loading model state_dict from {dit_checkpoint_path}...")
depth=48, dit_sd = load_torch_file(dit_checkpoint_path)
patch_size=2, if is_accelerate_available:
num_heads=24, print("Using accelerate to load and assign model weights to device...")
hidden_size_x=3072, for name, param in model.named_parameters():
hidden_size_y=1536, if not any(keyword in name for keyword in params_to_keep):
mlp_ratio_x=4.0, set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
mlp_ratio_y=4.0, else:
in_channels=12, set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
qk_norm=True, else:
qkv_bias=False, print("Loading state_dict without accelerate...")
out_bias=True, model.load_state_dict(dit_sd)
patch_embed_bias=True, for name, param in model.named_parameters():
timestep_mlp_bias=True, if not any(keyword in name for keyword in params_to_keep):
timestep_scale=1000.0, param.data = param.data.to(weight_dtype)
t5_feat_dim=4096, else:
t5_token_length=256, param.data = param.data.to(torch.bfloat16)
rope_theta=10000.0,
)
with t("dit_load_checkpoint"):
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
print(f"Loading model state_dict from {dit_checkpoint_path}...")
dit_sd = load_torch_file(dit_checkpoint_path)
if is_accelerate_available:
print("Using accelerate to load and assign model weights to device...")
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
else:
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
else:
print("Loading state_dict without accelerate...")
model.load_state_dict(dit_sd)
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(weight_dtype)
else:
param.data = param.data.to(torch.bfloat16)
if fp8_fastmode: if fp8_fastmode:
from ..fp8_optimization import convert_fp8_linear from ..fp8_optimization import convert_fp8_linear
convert_fp8_linear(model, torch.bfloat16) convert_fp8_linear(model, torch.bfloat16)
self.dit = model self.dit = model
self.dit.eval() self.dit.eval()
@ -179,8 +176,6 @@ class T2VSynthMochiModel:
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
t.print_stats()
def get_conditioning(self, prompts, *, zero_last_n_prompts: int): def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
B = len(prompts) B = len(prompts)
assert ( assert (