torch compile for vae loader

kijai 2024-10-26 03:24:25 +03:00
parent b7d3fc5e73
commit 0d15c0bd69
2 changed files with 19 additions and 4 deletions

View File

@@ -33,7 +33,7 @@ class SafeConv3d(torch.nn.Conv3d):
NOTE: No support for padding along time dimension.
Input must already be padded along time.
"""
+@torch.compiler.disable()
def forward(self, input):
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
if memory_count > 2:
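
The decorator added here keeps SafeConv3d.forward out of any torch.compile'd graph: the method branches on the runtime input size, which would otherwise cause graph breaks or per-shape recompilation. Below is a minimal sketch of the pattern, not part of the commit; the chunking branch is illustrative and stands in for the repo's actual time-chunked implementation.

    import torch

    class SafeConv3d(torch.nn.Conv3d):
        # Excluded from compilation: the branch below depends on the runtime
        # input shape, so tracing it would break or recompile the graph.
        @torch.compiler.disable()
        def forward(self, input):
            # Rough activation-memory estimate in GiB (2 bytes per bf16/fp16 element).
            memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
            if memory_count > 2 and input.shape[0] > 1:
                # Illustrative fallback: splitting along the batch dimension is always
                # safe for a convolution; the real SafeConv3d chunks along the
                # (already padded) time dimension instead.
                parts = [super(SafeConv3d, self).forward(p) for p in input.chunk(2, dim=0)]
                return torch.cat(parts, dim=0)
            return super().forward(input)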

View File

@@ -1,4 +1,7 @@
import os
+# import torch._dynamo
+# torch._dynamo.config.suppress_errors = True
import torch
import folder_paths
import comfy.model_management as mm
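
The commented-out lines are an escape hatch rather than an active part of the change: torch._dynamo.config.suppress_errors = True tells Dynamo to log a compilation failure and fall back to eager execution instead of raising. Enabling it would look like this (illustrative sketch, not in the commit):

    import torch
    import torch._dynamo

    # Fall back to eager execution (with a logged warning) when a region fails to
    # compile, instead of aborting the whole run.
    torch._dynamo.config.suppress_errors = True

    # Modules compiled afterwards degrade gracefully for any failing subgraphs.
    model = torch.compile(torch.nn.Linear(8, 8))
    out = model(torch.randn(2, 8))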
@@ -230,6 +233,9 @@ class MochiVAELoader:
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
},
"optional": {
"torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
},
}
RETURN_TYPES = ("MOCHIVAE",)
@@ -237,7 +243,7 @@ class MochiVAELoader:
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
-def loadmodel(self, model_name):
+def loadmodel(self, model_name, torch_compile_args=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
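
loadmodel now receives the compile settings as an optional keyword argument (ComfyUI only passes optional inputs when they are connected, hence the None default) and resolves both the compute and the offload device from comfy.model_management. A small sketch of that device-handling pattern; the decode helper is a hypothetical illustration, not code from the node:

    import torch
    import comfy.model_management as mm

    device = mm.get_torch_device()             # compute device, e.g. cuda:0
    offload_device = mm.unet_offload_device()  # where idle weights are parked, usually cpu

    # Hypothetical helper: keep the VAE on the offload device between runs and move it
    # to the compute device only for the duration of a decode.
    def decode_with_offload(vae, latents):
        vae.to(device)
        try:
            return vae.decode(latents)
        finally:
            vae.to(offload_device)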
@@ -264,12 +270,21 @@ class MochiVAELoader:
vae_sd = load_torch_file(vae_path)
if is_accelerate_available:
for key in vae_sd:
-set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
+set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
else:
vae.load_state_dict(vae_sd, strict=True)
-vae.eval().to(torch.bfloat16).to("cpu")
+vae.to(torch.bfloat16).to("cpu")
+vae.eval()
del vae_sd
+if torch_compile_args is not None:
+vae.to(device)
+# for i, block in enumerate(vae.blocks):
+# if "CausalUpsampleBlock" in str(type(block)):
+# print("Compiling block", block)
+vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
return (vae,)
class MochiTextEncode:
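
Putting the tail of loadmodel together: weights are now materialized onto the offload device (so the GPU is typically untouched at load time), the model is cast to bfloat16 on the CPU, and only when compile args are supplied is it moved to the compute device and wrapped with torch.compile. A condensed sketch of that flow, assuming a meta-initialized vae and the illustrative torch_compile_args dict from above; the helper name is hypothetical:

    import torch
    from accelerate.utils import set_module_tensor_to_device

    def finish_vae_load(vae, vae_sd, device, offload_device, torch_compile_args=None):
        # Materialize each tensor of the meta-initialized VAE on the offload device,
        # so no VRAM is consumed just by loading the checkpoint.
        for key in vae_sd:
            set_module_tensor_to_device(vae, key, dtype=torch.float32,
                                        device=offload_device, value=vae_sd[key])
        vae.to(torch.bfloat16).to("cpu")
        vae.eval()
        del vae_sd

        if torch_compile_args is not None:
            # torch.compile is lazy: the move to the compute device just ensures the
            # first decode traces and compiles kernels for the right device.
            vae.to(device)
            vae = torch.compile(
                vae,
                fullgraph=torch_compile_args["fullgraph"],
                mode=torch_compile_args["mode"],
                dynamic=False,   # fixed shapes; a new latent size triggers recompilation
                backend=torch_compile_args["backend"],
            )
        return vae

Since dynamic=False is hard-coded in the node, feeding latents with a different resolution or frame count later will recompile the VAE on first use of that shape.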