From db87f8e6084318704d982dc2396113e1474a0538 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:05:19 +0300 Subject: [PATCH] Add accelerate --- mochi_preview/t2v_synth_mochi.py | 55 +++++++++++++++++++------------- nodes.py | 54 ++++++++++++++++++++----------- requirements.txt | 2 ++ 3 files changed, 70 insertions(+), 41 deletions(-) create mode 100644 requirements.txt diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 15fb287..f696af0 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -13,7 +13,16 @@ from torch import nn from .dit.joint_model.context_parallel import get_cp_rank_size from .utils import Timer from tqdm import tqdm -from comfy.utils import ProgressBar +from comfy.utils import ProgressBar, load_torch_file + +from contextlib import nullcontext +try: + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + is_accelerate_available = True +except: + is_accelerate_available = False + pass MAX_T5_TOKEN_LENGTH = 256 @@ -138,29 +147,33 @@ class T2VSynthMochiModel: rope_theta=10000.0, ) with t("dit_load_checkpoint"): - - model.load_state_dict(load_file(dit_checkpoint_path)) + params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} + dit_sd = load_torch_file(dit_checkpoint_path) + if is_accelerate_available: + for name, param in model.named_parameters(): + if not any(keyword in name for keyword in params_to_keep): + set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name]) + else: + set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name]) + else: + model.load_state_dict(dit_sd) + for name, param in self.dit.named_parameters(): + if not any(keyword in name for keyword in params_to_keep): + param.data = param.data.to(weight_dtype) + else: + param.data = param.data.to(torch.bfloat16) - #with t("fsdp_dit"): self.dit = model self.dit.eval() - for name, param in self.dit.named_parameters(): - params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} - if not any(keyword in name for keyword in params_to_keep): - param.data = param.data.to(weight_dtype) - else: - param.data = param.data.to(torch.bfloat16) - vae_stats = json.load(open(vae_stats_path)) self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) - t.print_stats() + #t.print_stats() def get_conditioning(self, prompts, *, zero_last_n_prompts: int): B = len(prompts) - print(f"Getting conditioning for {B} prompts") assert ( 0 <= zero_last_n_prompts <= B ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}" @@ -198,8 +211,6 @@ class T2VSynthMochiModel: caption_input_ids_t5, caption_attention_mask_t5 ).last_hidden_state.detach().to(torch.float32) ) - print(y_feat.shape) - print(y_feat[0]) self.t5_enc.to("cpu") # Sometimes returns a tensor, othertimes a tuple, not sure why # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3 @@ -298,13 +309,13 @@ class T2VSynthMochiModel: # print(type(sample["y_feat"])) # print(sample["y_feat"][0].shape) - print(sample_null["y_mask"]) - print(type(sample_null["y_mask"])) - print(sample_null["y_mask"][0].shape) + # print(sample_null["y_mask"]) + # print(type(sample_null["y_mask"])) + # print(sample_null["y_mask"][0].shape) - print(sample_null["y_feat"]) - print(type(sample_null["y_feat"])) - print(sample_null["y_feat"][0].shape) + # print(sample_null["y_feat"]) + # print(type(sample_null["y_feat"])) + # print(sample_null["y_feat"][0].shape) sample["packed_indices"] = self.get_packed_indices( sample["y_mask"], **latent_dims @@ -359,5 +370,5 @@ class T2VSynthMochiModel: self.dit.to("cpu") samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std) - print("samples", samples.shape, samples.dtype, samples.device) + print("samples: ", samples.shape, samples.dtype, samples.device) return samples diff --git a/nodes.py b/nodes.py index 04ae75f..8fcb47a 100644 --- a/nodes.py +++ b/nodes.py @@ -12,6 +12,15 @@ log = logging.getLogger(__name__) from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel from .mochi_preview.vae.model import Decoder +from contextlib import nullcontext +try: + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + is_accelerate_available = True +except: + is_accelerate_available = False + pass + script_directory = os.path.dirname(os.path.abspath(__file__)) def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): @@ -40,11 +49,13 @@ class DownloadAndLoadMochiModel: [ "mochi_preview_dit_fp8_e4m3fn.safetensors", ], + {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", }, ), "vae": ( [ "mochi_preview_vae_bf16.safetensors", ], + {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", }, ), "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], {"default": "fp8_e4m3fn", } @@ -101,25 +112,30 @@ class DownloadAndLoadMochiModel: dit_checkpoint_path=model_path, weight_dtype=dtype, ) - vae = Decoder( - out_channels=3, - base_channels=128, - channel_multipliers=[1, 2, 4, 6], - temporal_expansions=[1, 2, 3], - spatial_expansions=[2, 2, 2], - num_res_blocks=[3, 3, 4, 6, 3], - latent_dim=12, - has_attention=[False, False, False, False, False], - padding_mode="replicate", - output_norm=False, - nonlinearity="silu", - output_nonlinearity="silu", - causal=True, - ) - decoder_sd = load_torch_file(vae_path) - vae.load_state_dict(decoder_sd, strict=True) - vae.eval().to(torch.bfloat16).to("cpu") - del decoder_sd + with (init_empty_weights() if is_accelerate_available else nullcontext()): + vae = Decoder( + out_channels=3, + base_channels=128, + channel_multipliers=[1, 2, 4, 6], + temporal_expansions=[1, 2, 3], + spatial_expansions=[2, 2, 2], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + has_attention=[False, False, False, False, False], + padding_mode="replicate", + output_norm=False, + nonlinearity="silu", + output_nonlinearity="silu", + causal=True, + ) + vae_sd = load_torch_file(vae_path) + if is_accelerate_available: + for key in vae_sd: + set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key]) + else: + vae.load_state_dict(vae_sd, strict=True) + vae.eval().to(torch.bfloat16).to("cpu") + del vae_sd return (model, vae,) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b4f81fc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +accelerate +einops \ No newline at end of file