Add model and vae loader nodes

kijai 2024-10-24 21:38:06 +03:00
parent 813d6aa92f
commit 813bbe8f4b
2 changed files with 95 additions and 1 deletion

View File

@@ -46,8 +46,10 @@ COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
 COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
 backends = []
-if torch.cuda.get_device_properties(0).major < 7:
+if torch.cuda.get_device_properties(0).major <= 7.5:
     backends.append(SDPBackend.MATH)
+if torch.cuda.get_device_properties(0).major >= 9.0:
+    backends.append(SDPBackend.CUDNN_ATTENTION)
 else:
     backends.append(SDPBackend.EFFICIENT_ATTENTION)
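
Editor's note: a minimal sketch, not part of this commit, of how a backends list like the one built above is typically consumed. torch.nn.attention.sdpa_kernel (PyTorch 2.3+) restricts scaled_dot_product_attention to the listed backends while the context is active; the tensor shapes and the MATH backend choice below are arbitrary.

# Illustrative only; MATH is always available, while the diff above enables
# CUDNN_ATTENTION on SM 9.0+ GPUs and EFFICIENT_ATTENTION otherwise.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

backends = [SDPBackend.MATH]
q = k = v = torch.randn(1, 8, 128, 64)
with sdpa_kernel(backends):  # only the listed backends may be dispatched here
    out = F.scaled_dot_product_attention(q, k, v)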

View File

@@ -149,6 +149,94 @@ class DownloadAndLoadMochiModel:
         return (model, vae,)
 
+class MochiModelLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}),
+                "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"], {"default": "fp8_e4m3fn"}),
+                "attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],),
+            },
+            "optional": {
+                "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
+            },
+        }
+
+    RETURN_TYPES = ("MOCHIMODEL",)
+    RETURN_NAMES = ("mochi_model",)
+    FUNCTION = "loadmodel"
+    CATEGORY = "MochiWrapper"
+
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+
+        model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
+
+        model = T2VSynthMochiModel(
+            device=device,
+            offload_device=offload_device,
+            vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
+            dit_checkpoint_path=model_path,
+            weight_dtype=dtype,
+            fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
+            attention_mode=attention_mode
+        )
+
+        return (model, )
+
+class MochiVAELoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
+            },
+        }
+
+    RETURN_TYPES = ("MOCHIVAE",)
+    RETURN_NAMES = ("mochi_vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "MochiWrapper"
+
+    def loadmodel(self, model_name):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
+
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            vae = Decoder(
+                out_channels=3,
+                base_channels=128,
+                channel_multipliers=[1, 2, 4, 6],
+                temporal_expansions=[1, 2, 3],
+                spatial_expansions=[2, 2, 2],
+                num_res_blocks=[3, 3, 4, 6, 3],
+                latent_dim=12,
+                has_attention=[False, False, False, False, False],
+                padding_mode="replicate",
+                output_norm=False,
+                nonlinearity="silu",
+                output_nonlinearity="silu",
+                causal=True,
+            )
+        vae_sd = load_torch_file(vae_path)
+        if is_accelerate_available:
+            for key in vae_sd:
+                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
+        else:
+            vae.load_state_dict(vae_sd, strict=True)
+        vae.eval().to(torch.bfloat16).to("cpu")
+        del vae_sd
+
+        return (vae,)
+
 class MochiTextEncode:
     @classmethod
     def INPUT_TYPES(s):
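
Editor's note: MochiVAELoader above builds the Decoder under accelerate's init_empty_weights() so no real storage is allocated up front, then materializes each tensor straight from the state dict with set_module_tensor_to_device. A minimal sketch of that pattern, assuming accelerate is installed; the Linear module below is a stand-in, not the actual Mochi Decoder.

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module on the "meta" device: parameters exist but own no memory yet.
with init_empty_weights():
    module = torch.nn.Linear(12, 128)

# Materialize each weight directly from a (here synthetic) state dict.
state_dict = {"weight": torch.randn(128, 12), "bias": torch.randn(128)}
for name in state_dict:
    set_module_tensor_to_device(module, name, device="cpu", dtype=torch.float32, value=state_dict[name])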
@@ -381,10 +469,14 @@ NODE_CLASS_MAPPINGS = {
     "MochiSampler": MochiSampler,
     "MochiDecode": MochiDecode,
     "MochiTextEncode": MochiTextEncode,
+    "MochiModelLoader": MochiModelLoader,
+    "MochiVAELoader": MochiVAELoader,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
     "MochiSampler": "Mochi Sampler",
     "MochiDecode": "Mochi Decode",
     "MochiTextEncode": "Mochi TextEncode",
+    "MochiModelLoader": "Mochi Model Loader",
+    "MochiVAELoader": "Mochi VAE Loader",
 }
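
Editor's note: at runtime ComfyUI instantiates a class from NODE_CLASS_MAPPINGS and calls the method named by FUNCTION with the gathered inputs, so the two new loaders can also be exercised directly. A rough sketch; the checkpoint filenames are hypothetical and would need to exist in ComfyUI's diffusion_models and vae folders.

# Input names match the INPUT_TYPES definitions added in this commit.
model_loader = MochiModelLoader()
(mochi_model,) = model_loader.loadmodel(
    model_name="mochi_preview_dit_fp8_e4m3fn.safetensors",  # hypothetical filename
    precision="fp8_e4m3fn",
    attention_mode="sdpa",
)

vae_loader = MochiVAELoader()
(mochi_vae,) = vae_loader.loadmodel(model_name="mochi_vae_decoder_bf16.safetensors")  # hypothetical filename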