torch compile for vae loader
This commit is contained in:
parent
b7d3fc5e73
commit
0d15c0bd69
@ -33,7 +33,7 @@ class SafeConv3d(torch.nn.Conv3d):
|
|||||||
NOTE: No support for padding along time dimension.
|
NOTE: No support for padding along time dimension.
|
||||||
Input must already be padded along time.
|
Input must already be padded along time.
|
||||||
"""
|
"""
|
||||||
|
@torch.compiler.disable()
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
||||||
if memory_count > 2:
|
if memory_count > 2:
|
||||||
|
|||||||
21
nodes.py
21
nodes.py
@ -1,4 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
# import torch._dynamo
|
||||||
|
# torch._dynamo.config.suppress_errors = True
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
@ -230,6 +233,9 @@ class MochiVAELoader:
|
|||||||
"required": {
|
"required": {
|
||||||
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
|
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
|
||||||
},
|
},
|
||||||
|
"optional": {
|
||||||
|
"torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("MOCHIVAE",)
|
RETURN_TYPES = ("MOCHIVAE",)
|
||||||
@ -237,7 +243,7 @@ class MochiVAELoader:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "MochiWrapper"
|
CATEGORY = "MochiWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model_name):
|
def loadmodel(self, model_name, torch_compile_args=None):
|
||||||
|
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
@ -264,12 +270,21 @@ class MochiVAELoader:
|
|||||||
vae_sd = load_torch_file(vae_path)
|
vae_sd = load_torch_file(vae_path)
|
||||||
if is_accelerate_available:
|
if is_accelerate_available:
|
||||||
for key in vae_sd:
|
for key in vae_sd:
|
||||||
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
|
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
|
||||||
else:
|
else:
|
||||||
vae.load_state_dict(vae_sd, strict=True)
|
vae.load_state_dict(vae_sd, strict=True)
|
||||||
vae.eval().to(torch.bfloat16).to("cpu")
|
vae.to(torch.bfloat16).to("cpu")
|
||||||
|
vae.eval()
|
||||||
del vae_sd
|
del vae_sd
|
||||||
|
|
||||||
|
if torch_compile_args is not None:
|
||||||
|
vae.to(device)
|
||||||
|
# for i, block in enumerate(vae.blocks):
|
||||||
|
# if "CausalUpsampleBlock" in str(type(block)):
|
||||||
|
# print("Compiling block", block)
|
||||||
|
vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
|
||||||
|
|
||||||
|
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class MochiTextEncode:
|
class MochiTextEncode:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user