Add accelerate
This commit is contained in:
parent
34e029bacc
commit
db87f8e608
@ -13,7 +13,16 @@ from torch import nn
|
|||||||
from .dit.joint_model.context_parallel import get_cp_rank_size
|
from .dit.joint_model.context_parallel import get_cp_rank_size
|
||||||
from .utils import Timer
|
from .utils import Timer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar, load_torch_file
|
||||||
|
|
||||||
|
from contextlib import nullcontext
|
||||||
|
try:
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
is_accelerate_available = True
|
||||||
|
except:
|
||||||
|
is_accelerate_available = False
|
||||||
|
pass
|
||||||
|
|
||||||
MAX_T5_TOKEN_LENGTH = 256
|
MAX_T5_TOKEN_LENGTH = 256
|
||||||
|
|
||||||
@ -138,29 +147,33 @@ class T2VSynthMochiModel:
|
|||||||
rope_theta=10000.0,
|
rope_theta=10000.0,
|
||||||
)
|
)
|
||||||
with t("dit_load_checkpoint"):
|
with t("dit_load_checkpoint"):
|
||||||
|
|
||||||
model.load_state_dict(load_file(dit_checkpoint_path))
|
|
||||||
|
|
||||||
#with t("fsdp_dit"):
|
|
||||||
self.dit = model
|
|
||||||
self.dit.eval()
|
|
||||||
for name, param in self.dit.named_parameters():
|
|
||||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||||
|
dit_sd = load_torch_file(dit_checkpoint_path)
|
||||||
|
if is_accelerate_available:
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
|
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
||||||
|
else:
|
||||||
|
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
||||||
|
else:
|
||||||
|
model.load_state_dict(dit_sd)
|
||||||
|
for name, param in self.dit.named_parameters():
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
param.data = param.data.to(weight_dtype)
|
param.data = param.data.to(weight_dtype)
|
||||||
else:
|
else:
|
||||||
param.data = param.data.to(torch.bfloat16)
|
param.data = param.data.to(torch.bfloat16)
|
||||||
|
|
||||||
|
self.dit = model
|
||||||
|
self.dit.eval()
|
||||||
|
|
||||||
vae_stats = json.load(open(vae_stats_path))
|
vae_stats = json.load(open(vae_stats_path))
|
||||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||||
|
|
||||||
t.print_stats()
|
#t.print_stats()
|
||||||
|
|
||||||
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
||||||
B = len(prompts)
|
B = len(prompts)
|
||||||
print(f"Getting conditioning for {B} prompts")
|
|
||||||
assert (
|
assert (
|
||||||
0 <= zero_last_n_prompts <= B
|
0 <= zero_last_n_prompts <= B
|
||||||
), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
|
), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
|
||||||
@ -198,8 +211,6 @@ class T2VSynthMochiModel:
|
|||||||
caption_input_ids_t5, caption_attention_mask_t5
|
caption_input_ids_t5, caption_attention_mask_t5
|
||||||
).last_hidden_state.detach().to(torch.float32)
|
).last_hidden_state.detach().to(torch.float32)
|
||||||
)
|
)
|
||||||
print(y_feat.shape)
|
|
||||||
print(y_feat[0])
|
|
||||||
self.t5_enc.to("cpu")
|
self.t5_enc.to("cpu")
|
||||||
# Sometimes returns a tensor, othertimes a tuple, not sure why
|
# Sometimes returns a tensor, othertimes a tuple, not sure why
|
||||||
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
|
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
|
||||||
@ -298,13 +309,13 @@ class T2VSynthMochiModel:
|
|||||||
# print(type(sample["y_feat"]))
|
# print(type(sample["y_feat"]))
|
||||||
# print(sample["y_feat"][0].shape)
|
# print(sample["y_feat"][0].shape)
|
||||||
|
|
||||||
print(sample_null["y_mask"])
|
# print(sample_null["y_mask"])
|
||||||
print(type(sample_null["y_mask"]))
|
# print(type(sample_null["y_mask"]))
|
||||||
print(sample_null["y_mask"][0].shape)
|
# print(sample_null["y_mask"][0].shape)
|
||||||
|
|
||||||
print(sample_null["y_feat"])
|
# print(sample_null["y_feat"])
|
||||||
print(type(sample_null["y_feat"]))
|
# print(type(sample_null["y_feat"]))
|
||||||
print(sample_null["y_feat"][0].shape)
|
# print(sample_null["y_feat"][0].shape)
|
||||||
|
|
||||||
sample["packed_indices"] = self.get_packed_indices(
|
sample["packed_indices"] = self.get_packed_indices(
|
||||||
sample["y_mask"], **latent_dims
|
sample["y_mask"], **latent_dims
|
||||||
@ -359,5 +370,5 @@ class T2VSynthMochiModel:
|
|||||||
self.dit.to("cpu")
|
self.dit.to("cpu")
|
||||||
|
|
||||||
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
||||||
print("samples", samples.shape, samples.dtype, samples.device)
|
print("samples: ", samples.shape, samples.dtype, samples.device)
|
||||||
return samples
|
return samples
|
||||||
|
|||||||
22
nodes.py
22
nodes.py
@ -12,6 +12,15 @@ log = logging.getLogger(__name__)
|
|||||||
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
|
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
|
||||||
from .mochi_preview.vae.model import Decoder
|
from .mochi_preview.vae.model import Decoder
|
||||||
|
|
||||||
|
from contextlib import nullcontext
|
||||||
|
try:
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
is_accelerate_available = True
|
||||||
|
except:
|
||||||
|
is_accelerate_available = False
|
||||||
|
pass
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||||
@ -40,11 +49,13 @@ class DownloadAndLoadMochiModel:
|
|||||||
[
|
[
|
||||||
"mochi_preview_dit_fp8_e4m3fn.safetensors",
|
"mochi_preview_dit_fp8_e4m3fn.safetensors",
|
||||||
],
|
],
|
||||||
|
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
|
||||||
),
|
),
|
||||||
"vae": (
|
"vae": (
|
||||||
[
|
[
|
||||||
"mochi_preview_vae_bf16.safetensors",
|
"mochi_preview_vae_bf16.safetensors",
|
||||||
],
|
],
|
||||||
|
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
|
||||||
),
|
),
|
||||||
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
||||||
{"default": "fp8_e4m3fn", }
|
{"default": "fp8_e4m3fn", }
|
||||||
@ -101,6 +112,7 @@ class DownloadAndLoadMochiModel:
|
|||||||
dit_checkpoint_path=model_path,
|
dit_checkpoint_path=model_path,
|
||||||
weight_dtype=dtype,
|
weight_dtype=dtype,
|
||||||
)
|
)
|
||||||
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
vae = Decoder(
|
vae = Decoder(
|
||||||
out_channels=3,
|
out_channels=3,
|
||||||
base_channels=128,
|
base_channels=128,
|
||||||
@ -116,10 +128,14 @@ class DownloadAndLoadMochiModel:
|
|||||||
output_nonlinearity="silu",
|
output_nonlinearity="silu",
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
decoder_sd = load_torch_file(vae_path)
|
vae_sd = load_torch_file(vae_path)
|
||||||
vae.load_state_dict(decoder_sd, strict=True)
|
if is_accelerate_available:
|
||||||
|
for key in vae_sd:
|
||||||
|
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
|
||||||
|
else:
|
||||||
|
vae.load_state_dict(vae_sd, strict=True)
|
||||||
vae.eval().to(torch.bfloat16).to("cpu")
|
vae.eval().to(torch.bfloat16).to("cpu")
|
||||||
del decoder_sd
|
del vae_sd
|
||||||
|
|
||||||
return (model, vae,)
|
return (model, vae,)
|
||||||
|
|
||||||
|
|||||||
2
requirements.txt
Normal file
2
requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
accelerate
|
||||||
|
einops
|
||||||
Loading…
x
Reference in New Issue
Block a user