Update t2v_synth_mochi.py
This commit is contained in:
parent
508eaa22df
commit
bd954ec132
@ -10,7 +10,6 @@ import torch.utils.data
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .dit.joint_model.context_parallel import get_cp_rank_size
|
from .dit.joint_model.context_parallel import get_cp_rank_size
|
||||||
from .utils import Timer
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from comfy.utils import ProgressBar, load_torch_file
|
from comfy.utils import ProgressBar, load_torch_file
|
||||||
|
|
||||||
@ -122,55 +121,53 @@ class T2VSynthMochiModel:
|
|||||||
fp8_fastmode: bool = False,
|
fp8_fastmode: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
t = Timer()
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
|
|
||||||
with t("construct_dit"):
|
print("Initializing model...")
|
||||||
|
model: nn.Module = torch.nn.utils.skip_init(
|
||||||
|
AsymmDiTJoint,
|
||||||
|
depth=48,
|
||||||
|
patch_size=2,
|
||||||
|
num_heads=24,
|
||||||
|
hidden_size_x=3072,
|
||||||
|
hidden_size_y=1536,
|
||||||
|
mlp_ratio_x=4.0,
|
||||||
|
mlp_ratio_y=4.0,
|
||||||
|
in_channels=12,
|
||||||
|
qk_norm=True,
|
||||||
|
qkv_bias=False,
|
||||||
|
out_bias=True,
|
||||||
|
patch_embed_bias=True,
|
||||||
|
timestep_mlp_bias=True,
|
||||||
|
timestep_scale=1000.0,
|
||||||
|
t5_feat_dim=4096,
|
||||||
|
t5_token_length=256,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
model: nn.Module = torch.nn.utils.skip_init(
|
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||||
AsymmDiTJoint,
|
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||||
depth=48,
|
dit_sd = load_torch_file(dit_checkpoint_path)
|
||||||
patch_size=2,
|
if is_accelerate_available:
|
||||||
num_heads=24,
|
print("Using accelerate to load and assign model weights to device...")
|
||||||
hidden_size_x=3072,
|
for name, param in model.named_parameters():
|
||||||
hidden_size_y=1536,
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
mlp_ratio_x=4.0,
|
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
||||||
mlp_ratio_y=4.0,
|
else:
|
||||||
in_channels=12,
|
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
||||||
qk_norm=True,
|
else:
|
||||||
qkv_bias=False,
|
print("Loading state_dict without accelerate...")
|
||||||
out_bias=True,
|
model.load_state_dict(dit_sd)
|
||||||
patch_embed_bias=True,
|
for name, param in model.named_parameters():
|
||||||
timestep_mlp_bias=True,
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
timestep_scale=1000.0,
|
param.data = param.data.to(weight_dtype)
|
||||||
t5_feat_dim=4096,
|
else:
|
||||||
t5_token_length=256,
|
param.data = param.data.to(torch.bfloat16)
|
||||||
rope_theta=10000.0,
|
|
||||||
)
|
|
||||||
with t("dit_load_checkpoint"):
|
|
||||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
|
||||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
|
||||||
dit_sd = load_torch_file(dit_checkpoint_path)
|
|
||||||
if is_accelerate_available:
|
|
||||||
print("Using accelerate to load and assign model weights to device...")
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
|
||||||
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
|
||||||
else:
|
|
||||||
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
|
||||||
else:
|
|
||||||
print("Loading state_dict without accelerate...")
|
|
||||||
model.load_state_dict(dit_sd)
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
|
||||||
param.data = param.data.to(weight_dtype)
|
|
||||||
else:
|
|
||||||
param.data = param.data.to(torch.bfloat16)
|
|
||||||
|
|
||||||
if fp8_fastmode:
|
if fp8_fastmode:
|
||||||
from ..fp8_optimization import convert_fp8_linear
|
from ..fp8_optimization import convert_fp8_linear
|
||||||
convert_fp8_linear(model, torch.bfloat16)
|
convert_fp8_linear(model, torch.bfloat16)
|
||||||
|
|
||||||
self.dit = model
|
self.dit = model
|
||||||
self.dit.eval()
|
self.dit.eval()
|
||||||
@ -179,8 +176,6 @@ class T2VSynthMochiModel:
|
|||||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||||
|
|
||||||
t.print_stats()
|
|
||||||
|
|
||||||
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
||||||
B = len(prompts)
|
B = len(prompts)
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user