torch compile for vae loader
This commit is contained in:
parent
b7d3fc5e73
commit
0d15c0bd69
@ -33,7 +33,7 @@ class SafeConv3d(torch.nn.Conv3d):
|
||||
NOTE: No support for padding along time dimension.
|
||||
Input must already be padded along time.
|
||||
"""
|
||||
|
||||
@torch.compiler.disable()
|
||||
def forward(self, input):
|
||||
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
||||
if memory_count > 2:
|
||||
|
||||
21
nodes.py
21
nodes.py
@ -1,4 +1,7 @@
|
||||
import os
|
||||
# import torch._dynamo
|
||||
# torch._dynamo.config.suppress_errors = True
|
||||
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
@ -230,6 +233,9 @@ class MochiVAELoader:
|
||||
"required": {
|
||||
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
|
||||
},
|
||||
"optional": {
|
||||
"torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MOCHIVAE",)
|
||||
@ -237,7 +243,7 @@ class MochiVAELoader:
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def loadmodel(self, model_name):
|
||||
def loadmodel(self, model_name, torch_compile_args=None):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -264,12 +270,21 @@ class MochiVAELoader:
|
||||
vae_sd = load_torch_file(vae_path)
|
||||
if is_accelerate_available:
|
||||
for key in vae_sd:
|
||||
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
|
||||
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
|
||||
else:
|
||||
vae.load_state_dict(vae_sd, strict=True)
|
||||
vae.eval().to(torch.bfloat16).to("cpu")
|
||||
vae.to(torch.bfloat16).to("cpu")
|
||||
vae.eval()
|
||||
del vae_sd
|
||||
|
||||
if torch_compile_args is not None:
|
||||
vae.to(device)
|
||||
# for i, block in enumerate(vae.blocks):
|
||||
# if "CausalUpsampleBlock" in str(type(block)):
|
||||
# print("Compiling block", block)
|
||||
vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
|
||||
|
||||
|
||||
return (vae,)
|
||||
|
||||
class MochiTextEncode:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user