Add accelerate

parent 34e029bacc
commit db87f8e608
mochi_preview/t2v_synth_mochi.py
@@ -13,7 +13,16 @@ from torch import nn
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from .utils import Timer
 from tqdm import tqdm
-from comfy.utils import ProgressBar
+from comfy.utils import ProgressBar, load_torch_file
 
+from contextlib import nullcontext
+try:
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+    is_accelerate_available = True
+except ImportError:
+    is_accelerate_available = False
 
 MAX_T5_TOKEN_LENGTH = 256
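For readers unfamiliar with the pattern these imports enable: `init_empty_weights` builds a module on PyTorch's meta device, so no weight memory is allocated at construction time, and `set_module_tensor_to_device` then materializes each tensor straight from a checkpoint dict onto its target device and dtype. A minimal sketch with a toy `nn.Linear` (not the Mochi DiT):

```python
import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module skeleton on the meta device: no weight memory is allocated.
with init_empty_weights():
    model = nn.Linear(1024, 1024)

# Materialize each tensor from the checkpoint dict directly at its target
# device/dtype; the module itself was never allocated, so this is the only copy.
state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
for name, value in state_dict.items():
    set_module_tensor_to_device(model, name, device="cpu",
                                dtype=torch.bfloat16, value=value)
```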
@@ -138,29 +147,33 @@ class T2VSynthMochiModel:
             rope_theta=10000.0,
         )
         with t("dit_load_checkpoint"):
-            model.load_state_dict(load_file(dit_checkpoint_path))
-
-        #with t("fsdp_dit"):
-        self.dit = model
-        self.dit.eval()
-        for name, param in self.dit.named_parameters():
+            params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
+            dit_sd = load_torch_file(dit_checkpoint_path)
+            if is_accelerate_available:
+                for name, param in model.named_parameters():
+                    if not any(keyword in name for keyword in params_to_keep):
+                        set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
+                    else:
+                        set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
+            else:
+                model.load_state_dict(dit_sd)
+                for name, param in model.named_parameters():
+                    if not any(keyword in name for keyword in params_to_keep):
+                        param.data = param.data.to(weight_dtype)
+                    else:
+                        param.data = param.data.to(torch.bfloat16)
+
+        self.dit = model
+        self.dit.eval()
 
         vae_stats = json.load(open(vae_stats_path))
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
         self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
 
-        t.print_stats()
+        #t.print_stats()
 
     def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
         B = len(prompts)
         print(f"Getting conditioning for {B} prompts")
         assert (
             0 <= zero_last_n_prompts <= B
         ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
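The `params_to_keep` set implements selective precision: the bulk of the DiT weights are cast to `weight_dtype` (fp8_e4m3fn by default, per the node's precision option below), while embedders, positional frequencies, the T5 projection, and norm layers stay in bfloat16, where low-precision formats tend to hurt most. A standalone sketch of the same idea; the helper name is mine, not from this repo:

```python
import torch
from torch import nn

def cast_selectively(model: nn.Module, low_dtype: torch.dtype,
                     keep_in_bf16: set) -> None:
    """Cast most parameters to low_dtype; keep numerically sensitive ones in bf16."""
    for name, param in model.named_parameters():
        if any(keyword in name for keyword in keep_in_bf16):
            param.data = param.data.to(torch.bfloat16)
        else:
            param.data = param.data.to(low_dtype)

# Mirroring the commit's defaults:
# cast_selectively(model, torch.float8_e4m3fn,
#                  {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"})
```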
@@ -198,8 +211,6 @@ class T2VSynthMochiModel:
                 caption_input_ids_t5, caption_attention_mask_t5
             ).last_hidden_state.detach().to(torch.float32)
         )
-        print(y_feat.shape)
-        print(y_feat[0])
         self.t5_enc.to("cpu")
         # Sometimes returns a tensor, other times a tuple, not sure why
         # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
@@ -298,13 +309,13 @@ class T2VSynthMochiModel:
         # print(type(sample["y_feat"]))
         # print(sample["y_feat"][0].shape)
 
-        print(sample_null["y_mask"])
-        print(type(sample_null["y_mask"]))
-        print(sample_null["y_mask"][0].shape)
+        # print(sample_null["y_mask"])
+        # print(type(sample_null["y_mask"]))
+        # print(sample_null["y_mask"][0].shape)
 
-        print(sample_null["y_feat"])
-        print(type(sample_null["y_feat"]))
-        print(sample_null["y_feat"][0].shape)
+        # print(sample_null["y_feat"])
+        # print(type(sample_null["y_feat"]))
+        # print(sample_null["y_feat"][0].shape)
 
         sample["packed_indices"] = self.get_packed_indices(
             sample["y_mask"], **latent_dims
@@ -359,5 +370,5 @@ class T2VSynthMochiModel:
         self.dit.to("cpu")
 
         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
-        print("samples", samples.shape, samples.dtype, samples.device)
+        print("samples: ", samples.shape, samples.dtype, samples.device)
         return samples
nodes.py (22 lines changed)
@@ -12,6 +12,15 @@ log = logging.getLogger(__name__)
 from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
 from .mochi_preview.vae.model import Decoder
 
+from contextlib import nullcontext
+try:
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+    is_accelerate_available = True
+except ImportError:
+    is_accelerate_available = False
 
 script_directory = os.path.dirname(os.path.abspath(__file__))
 
 def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
@@ -40,11 +49,13 @@ class DownloadAndLoadMochiModel:
                     [
                         "mochi_preview_dit_fp8_e4m3fn.safetensors",
                     ],
+                    {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'"},
                 ),
                 "vae": (
                     [
                         "mochi_preview_vae_bf16.safetensors",
                     ],
+                    {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'"},
                 ),
                 "precision": (["fp8_e4m3fn", "fp16", "fp32", "bf16"],
                     {"default": "fp8_e4m3fn"}
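The added `{"tooltip": ...}` dicts use ComfyUI's per-input options: each required input is a `(type_or_choices, options)` tuple, where keys like `tooltip` and `default` are optional. A hypothetical, trimmed-down input spec in the same shape (not the full DownloadAndLoadMochiModel definition):

```python
@classmethod
def INPUT_TYPES(cls):
    # Illustrative only: a list as the first element renders as a dropdown,
    # and the options dict customizes the widget.
    return {
        "required": {
            "model": (
                ["mochi_preview_dit_fp8_e4m3fn.safetensors"],
                {"tooltip": "Shown when hovering the input in the UI"},
            ),
            "precision": (
                ["fp8_e4m3fn", "fp16", "fp32", "bf16"],
                {"default": "fp8_e4m3fn"},
            ),
        }
    }
```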
@@ -101,6 +112,7 @@ class DownloadAndLoadMochiModel:
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
         )
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(
                 out_channels=3,
                 base_channels=128,
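The `init_empty_weights() if is_accelerate_available else nullcontext()` expression is the standard optional-dependency fallback: `nullcontext()` is a no-op context manager, so without accelerate the Decoder is simply built with normally allocated weights and the plain `load_state_dict` path below still works. A minimal sketch of the pattern, with `build_model()` as a placeholder for e.g. `Decoder(...)`:

```python
from contextlib import nullcontext

try:
    from accelerate import init_empty_weights
    is_accelerate_available = True
except ImportError:
    is_accelerate_available = False

ctx = init_empty_weights() if is_accelerate_available else nullcontext()
with ctx:
    # With accelerate: parameters land on the meta device (no allocation) and
    # must be materialized later via set_module_tensor_to_device.
    # Without it: ordinary construction; load_state_dict() works as usual.
    model = build_model()  # hypothetical placeholder
```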
@@ -116,10 +128,14 @@ class DownloadAndLoadMochiModel:
                 output_nonlinearity="silu",
                 causal=True,
             )
-        decoder_sd = load_torch_file(vae_path)
-        vae.load_state_dict(decoder_sd, strict=True)
+        vae_sd = load_torch_file(vae_path)
+        if is_accelerate_available:
+            for key in vae_sd:
+                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
+        else:
+            vae.load_state_dict(vae_sd, strict=True)
         vae.eval().to(torch.bfloat16).to("cpu")
-        del decoder_sd
+        del vae_sd
 
         return (model, vae,)
requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
+accelerate
+einops