From 813bbe8f4bbee3b435832202fd3575706285b36c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 24 Oct 2024 21:38:06 +0300 Subject: [PATCH] Add model and vae loader nodes --- .../dit/joint_model/asymm_models_joint.py | 4 +- nodes.py | 92 +++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index 4150dc8..bf10a00 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -46,8 +46,10 @@ COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1" COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1" backends = [] -if torch.cuda.get_device_properties(0).major < 7: +if torch.cuda.get_device_properties(0).major <= 7.5: backends.append(SDPBackend.MATH) +if torch.cuda.get_device_properties(0).major >= 9.0: + backends.append(SDPBackend.CUDNN_ATTENTION) else: backends.append(SDPBackend.EFFICIENT_ATTENTION) diff --git a/nodes.py b/nodes.py index ee0e149..1c92a40 100644 --- a/nodes.py +++ b/nodes.py @@ -149,6 +149,94 @@ class DownloadAndLoadMochiModel: return (model, vae,) +class MochiModelLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}), + "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"], {"default": "fp8_e4m3fn"}), + "attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],), + }, + "optional": { + "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}), + }, + } + RETURN_TYPES = ("MOCHIMODEL",) + RETURN_NAMES = ("mochi_model",) + FUNCTION = "loadmodel" + CATEGORY = "MochiWrapper" + + def loadmodel(self, model_name, precision, attention_mode, trigger=None): + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] + model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name) + + model = T2VSynthMochiModel( + device=device, + offload_device=offload_device, + vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), + dit_checkpoint_path=model_path, + weight_dtype=dtype, + fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False, + attention_mode=attention_mode + ) + + return (model, ) + +class MochiVAELoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}), + }, + } + + RETURN_TYPES = ("MOCHIVAE",) + RETURN_NAMES = ("mochi_vae", ) + FUNCTION = "loadmodel" + CATEGORY = "MochiWrapper" + + def loadmodel(self, model_name): + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + vae_path = folder_paths.get_full_path_or_raise("vae", model_name) + + with (init_empty_weights() if is_accelerate_available else nullcontext()): + vae = Decoder( + out_channels=3, + base_channels=128, + channel_multipliers=[1, 2, 4, 6], + temporal_expansions=[1, 2, 3], + spatial_expansions=[2, 2, 2], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + has_attention=[False, False, False, False, False], + padding_mode="replicate", + output_norm=False, + nonlinearity="silu", + output_nonlinearity="silu", + causal=True, + ) + vae_sd = load_torch_file(vae_path) + if is_accelerate_available: + for key in vae_sd: + set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key]) + else: + vae.load_state_dict(vae_sd, strict=True) + vae.eval().to(torch.bfloat16).to("cpu") + del vae_sd + + return (vae,) + class MochiTextEncode: @classmethod def INPUT_TYPES(s): @@ -381,10 +469,14 @@ NODE_CLASS_MAPPINGS = { "MochiSampler": MochiSampler, "MochiDecode": MochiDecode, "MochiTextEncode": MochiTextEncode, + "MochiModelLoader": MochiModelLoader, + "MochiVAELoader": MochiVAELoader, } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadMochiModel": "(Down)load Mochi Model", "MochiSampler": "Mochi Sampler", "MochiDecode": "Mochi Decode", "MochiTextEncode": "Mochi TextEncode", + "MochiModelLoader": "Mochi Model Loader", + "MochiVAELoader": "Mochi VAE Loader", }