Add torch compile for main model

This commit is contained in:
kijai 2025-01-22 10:27:21 +02:00
parent bba70eb67c
commit 75ad5e41eb
2 changed files with 57 additions and 4 deletions

View File

@ -152,6 +152,7 @@ class Hunyuan3DDiTPipeline:
offload_device=torch.device('cpu'),
dtype=torch.float16,
use_safetensors=None,
compile_args=None,
**kwargs,
):
# load config
@ -200,6 +201,13 @@ class Hunyuan3DDiTPipeline:
image_processor = instantiate_from_config(config['image_processor'])
scheduler = instantiate_from_config(config['scheduler'])
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
if compile_args["compile_transformer"]:
model = torch.compile(model)
if compile_args["compile_vae"]:
vae = torch.compile(vae)
model_kwargs = dict(
vae=vae,
model=model,

View File

@ -15,6 +15,40 @@ script_directory = os.path.dirname(os.path.abspath(__file__))
from .utils import log, print_memory
class Hy3DTorchCompileSettings:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
"compile_transformer": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
"compile_vae": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
},
}
RETURN_TYPES = ("HY3DCOMPILEARGS",)
RETURN_NAMES = ("torch_compile_args",)
FUNCTION = "loadmodel"
CATEGORY = "HunyuanVideoWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
compile_args = {
"backend": backend,
"fullgraph": fullgraph,
"mode": mode,
"dynamic": dynamic,
"dynamo_cache_size_limit": dynamo_cache_size_limit,
"compile_transformer": compile_transformer,
"compile_vae": compile_vae,
}
return (compile_args, )
#region Model loading
class Hy3DModelLoader:
@classmethod
@ -23,6 +57,9 @@ class Hy3DModelLoader:
"required": {
"model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
},
"optional": {
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
}
}
RETURN_TYPES = ("HY3DMODEL",)
@ -30,13 +67,19 @@ class Hy3DModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "Hunyuan3DWrapper"
def loadmodel(self, model):
def loadmodel(self, model, compile_args=None):
device = mm.get_torch_device()
offload_device=mm.unet_offload_device()
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
model_path = folder_paths.get_full_path("diffusion_models", model)
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
ckpt_path=model_path,
config_path=config_path,
use_safetensors=True,
device=device,
offload_device=offload_device,
compile_args=compile_args)
return (pipe,)
class DownloadAndLoadHy3DDelightModel:
@ -455,7 +498,8 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
"Hy3DDelightImage": Hy3DDelightImage,
"Hy3DRenderMultiView": Hy3DRenderMultiView,
"Hy3DBakeFromMultiview": Hy3DBakeFromMultiview
"Hy3DBakeFromMultiview": Hy3DBakeFromMultiview,
"Hy3DTorchCompileSettings": Hy3DTorchCompileSettings
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DModelLoader": "Hy3DModelLoader",
@ -465,5 +509,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",
"Hy3DDelightImage": "Hy3DDelightImage",
"Hy3DRenderMultiView": "Hy3D Render MultiView",
"Hy3DBakeFromMultiview": "Hy3D Bake From Multiview"
"Hy3DBakeFromMultiview": "Hy3D Bake From Multiview",
"Hy3DTorchCompileSettings": "Hy3D Torch Compile Settings"
}