Add accelerate

This commit is contained in:
kijai 2024-10-23 16:05:19 +03:00
parent 34e029bacc
commit db87f8e608
3 changed files with 70 additions and 41 deletions

View File

@ -13,7 +13,16 @@ from torch import nn
from .dit.joint_model.context_parallel import get_cp_rank_size from .dit.joint_model.context_parallel import get_cp_rank_size
from .utils import Timer from .utils import Timer
from tqdm import tqdm from tqdm import tqdm
from comfy.utils import ProgressBar from comfy.utils import ProgressBar, load_torch_file
from contextlib import nullcontext
try:
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
is_accelerate_available = True
except:
is_accelerate_available = False
pass
MAX_T5_TOKEN_LENGTH = 256 MAX_T5_TOKEN_LENGTH = 256
@ -138,29 +147,33 @@ class T2VSynthMochiModel:
rope_theta=10000.0, rope_theta=10000.0,
) )
with t("dit_load_checkpoint"): with t("dit_load_checkpoint"):
model.load_state_dict(load_file(dit_checkpoint_path))
#with t("fsdp_dit"):
self.dit = model
self.dit.eval()
for name, param in self.dit.named_parameters():
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
dit_sd = load_torch_file(dit_checkpoint_path)
if is_accelerate_available:
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
else:
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
else:
model.load_state_dict(dit_sd)
for name, param in self.dit.named_parameters():
if not any(keyword in name for keyword in params_to_keep): if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(weight_dtype) param.data = param.data.to(weight_dtype)
else: else:
param.data = param.data.to(torch.bfloat16) param.data = param.data.to(torch.bfloat16)
self.dit = model
self.dit.eval()
vae_stats = json.load(open(vae_stats_path)) vae_stats = json.load(open(vae_stats_path))
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
t.print_stats() #t.print_stats()
def get_conditioning(self, prompts, *, zero_last_n_prompts: int): def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
B = len(prompts) B = len(prompts)
print(f"Getting conditioning for {B} prompts")
assert ( assert (
0 <= zero_last_n_prompts <= B 0 <= zero_last_n_prompts <= B
), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}" ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
@ -198,8 +211,6 @@ class T2VSynthMochiModel:
caption_input_ids_t5, caption_attention_mask_t5 caption_input_ids_t5, caption_attention_mask_t5
).last_hidden_state.detach().to(torch.float32) ).last_hidden_state.detach().to(torch.float32)
) )
print(y_feat.shape)
print(y_feat[0])
self.t5_enc.to("cpu") self.t5_enc.to("cpu")
# Sometimes returns a tensor, other times a tuple, not sure why # Sometimes returns a tensor, other times a tuple, not sure why
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3 # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
@ -298,13 +309,13 @@ class T2VSynthMochiModel:
# print(type(sample["y_feat"])) # print(type(sample["y_feat"]))
# print(sample["y_feat"][0].shape) # print(sample["y_feat"][0].shape)
print(sample_null["y_mask"]) # print(sample_null["y_mask"])
print(type(sample_null["y_mask"])) # print(type(sample_null["y_mask"]))
print(sample_null["y_mask"][0].shape) # print(sample_null["y_mask"][0].shape)
print(sample_null["y_feat"]) # print(sample_null["y_feat"])
print(type(sample_null["y_feat"])) # print(type(sample_null["y_feat"]))
print(sample_null["y_feat"][0].shape) # print(sample_null["y_feat"][0].shape)
sample["packed_indices"] = self.get_packed_indices( sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims sample["y_mask"], **latent_dims
@ -359,5 +370,5 @@ class T2VSynthMochiModel:
self.dit.to("cpu") self.dit.to("cpu")
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std) samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
print("samples", samples.shape, samples.dtype, samples.device) print("samples: ", samples.shape, samples.dtype, samples.device)
return samples return samples

View File

@ -12,6 +12,15 @@ log = logging.getLogger(__name__)
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
from .mochi_preview.vae.model import Decoder from .mochi_preview.vae.model import Decoder
from contextlib import nullcontext
try:
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
is_accelerate_available = True
except:
is_accelerate_available = False
pass
script_directory = os.path.dirname(os.path.abspath(__file__)) script_directory = os.path.dirname(os.path.abspath(__file__))
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
@ -40,11 +49,13 @@ class DownloadAndLoadMochiModel:
[ [
"mochi_preview_dit_fp8_e4m3fn.safetensors", "mochi_preview_dit_fp8_e4m3fn.safetensors",
], ],
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
), ),
"vae": ( "vae": (
[ [
"mochi_preview_vae_bf16.safetensors", "mochi_preview_vae_bf16.safetensors",
], ],
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
), ),
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
{"default": "fp8_e4m3fn", } {"default": "fp8_e4m3fn", }
@ -101,6 +112,7 @@ class DownloadAndLoadMochiModel:
dit_checkpoint_path=model_path, dit_checkpoint_path=model_path,
weight_dtype=dtype, weight_dtype=dtype,
) )
with (init_empty_weights() if is_accelerate_available else nullcontext()):
vae = Decoder( vae = Decoder(
out_channels=3, out_channels=3,
base_channels=128, base_channels=128,
@ -116,10 +128,14 @@ class DownloadAndLoadMochiModel:
output_nonlinearity="silu", output_nonlinearity="silu",
causal=True, causal=True,
) )
decoder_sd = load_torch_file(vae_path) vae_sd = load_torch_file(vae_path)
vae.load_state_dict(decoder_sd, strict=True) if is_accelerate_available:
for key in vae_sd:
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
else:
vae.load_state_dict(vae_sd, strict=True)
vae.eval().to(torch.bfloat16).to("cpu") vae.eval().to(torch.bfloat16).to("cpu")
del decoder_sd del vae_sd
return (model, vae,) return (model, vae,)

2
requirements.txt Normal file
View File

@ -0,0 +1,2 @@
accelerate
einops