Add model and vae loader nodes
This commit is contained in:
parent
813d6aa92f
commit
813bbe8f4b
@ -46,8 +46,10 @@ COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
|
||||
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
|
||||
|
||||
backends = []
|
||||
if torch.cuda.get_device_properties(0).major < 7:
|
||||
if torch.cuda.get_device_properties(0).major <= 7.5:
|
||||
backends.append(SDPBackend.MATH)
|
||||
if torch.cuda.get_device_properties(0).major >= 9.0:
|
||||
backends.append(SDPBackend.CUDNN_ATTENTION)
|
||||
else:
|
||||
backends.append(SDPBackend.EFFICIENT_ATTENTION)
|
||||
|
||||
|
||||
92
nodes.py
92
nodes.py
@ -149,6 +149,94 @@ class DownloadAndLoadMochiModel:
|
||||
|
||||
return (model, vae,)
|
||||
|
||||
class MochiModelLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}),
|
||||
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"], {"default": "fp8_e4m3fn"}),
|
||||
"attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],),
|
||||
},
|
||||
"optional": {
|
||||
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("MOCHIMODEL",)
|
||||
RETURN_NAMES = ("mochi_model",)
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def loadmodel(self, model_name, precision, attention_mode, trigger=None):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
|
||||
|
||||
model = T2VSynthMochiModel(
|
||||
device=device,
|
||||
offload_device=offload_device,
|
||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||
dit_checkpoint_path=model_path,
|
||||
weight_dtype=dtype,
|
||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||
attention_mode=attention_mode
|
||||
)
|
||||
|
||||
return (model, )
|
||||
|
||||
class MochiVAELoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MOCHIVAE",)
|
||||
RETURN_NAMES = ("mochi_vae", )
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def loadmodel(self, model_name):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
|
||||
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
vae = Decoder(
|
||||
out_channels=3,
|
||||
base_channels=128,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
temporal_expansions=[1, 2, 3],
|
||||
spatial_expansions=[2, 2, 2],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
has_attention=[False, False, False, False, False],
|
||||
padding_mode="replicate",
|
||||
output_norm=False,
|
||||
nonlinearity="silu",
|
||||
output_nonlinearity="silu",
|
||||
causal=True,
|
||||
)
|
||||
vae_sd = load_torch_file(vae_path)
|
||||
if is_accelerate_available:
|
||||
for key in vae_sd:
|
||||
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
|
||||
else:
|
||||
vae.load_state_dict(vae_sd, strict=True)
|
||||
vae.eval().to(torch.bfloat16).to("cpu")
|
||||
del vae_sd
|
||||
|
||||
return (vae,)
|
||||
|
||||
class MochiTextEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -381,10 +469,14 @@ NODE_CLASS_MAPPINGS = {
|
||||
"MochiSampler": MochiSampler,
|
||||
"MochiDecode": MochiDecode,
|
||||
"MochiTextEncode": MochiTextEncode,
|
||||
"MochiModelLoader": MochiModelLoader,
|
||||
"MochiVAELoader": MochiVAELoader,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
|
||||
"MochiSampler": "Mochi Sampler",
|
||||
"MochiDecode": "Mochi Decode",
|
||||
"MochiTextEncode": "Mochi TextEncode",
|
||||
"MochiModelLoader": "Mochi Model Loader",
|
||||
"MochiVAELoader": "Mochi VAE Loader",
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user