From 0d15c0bd69d54fd1ac4727172f2e4648c1ebd307 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 26 Oct 2024 03:24:25 +0300
Subject: [PATCH] torch compile for vae loader

---
 mochi_preview/vae/model.py |  2 +-
 nodes.py                   | 21 ++++++++++++++++++---
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 1263271..2eef51a 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -33,7 +33,7 @@ class SafeConv3d(torch.nn.Conv3d):
     NOTE: No support for padding along time dimension.
           Input must already be padded along time.
     """
-
+    @torch.compiler.disable()
     def forward(self, input):
         memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
         if memory_count > 2:
diff --git a/nodes.py b/nodes.py
index d9f9dc6..e1c342e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,4 +1,7 @@
 import os
+# import torch._dynamo
+# torch._dynamo.config.suppress_errors = True
+
 import torch
 import folder_paths
 import comfy.model_management as mm
@@ -230,6 +233,9 @@ class MochiVAELoader:
             "required": {
                 "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
             },
+            "optional": {
+                "torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+            },
         }

     RETURN_TYPES = ("MOCHIVAE",)
@@ -237,7 +243,7 @@ class MochiVAELoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"

-    def loadmodel(self, model_name):
+    def loadmodel(self, model_name, torch_compile_args=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()

@@ -264,12 +270,21 @@ class MochiVAELoader:
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
             for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
+                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-            vae.eval().to(torch.bfloat16).to("cpu")
+            vae.to(torch.bfloat16).to("cpu")
+        vae.eval()
         del vae_sd

+        if torch_compile_args is not None:
+            vae.to(device)
+            # for i, block in enumerate(vae.blocks):
+            #     if "CausalUpsampleBlock" in str(type(block)):
+            #         print("Compiling block", block)
+            vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
+
         return (vae,)

 class MochiTextEncode:
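
Usage note: the new optional "torch_compile_args" input is consumed as a plain
dict; the patched loadmodel() reads only the "backend", "fullgraph", and
"mode" keys and hardcodes dynamic=False. SafeConv3d.forward is excluded from
compilation via @torch.compiler.disable(), presumably because it branches on
the input's memory footprint at runtime, which would otherwise force graph
breaks or per-shape recompilation. A minimal sketch of a dict that satisfies
the patched signature follows; the checkpoint name and the "inductor" backend
value are illustrative assumptions, not part of the patch:

    torch_compile_args = {
        "backend": "inductor",   # default torch.compile backend (illustrative choice)
        "fullgraph": False,      # tolerate graph breaks such as the disabled SafeConv3d.forward
        "mode": "default",       # or e.g. "max-autotune"
    }

    # Outside a ComfyUI graph, the node could be exercised directly
    # (hypothetical checkpoint name):
    vae, = MochiVAELoader().loadmodel("mochi_vae.safetensors", torch_compile_args=torch_compile_args)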