mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-08 20:34:28 +08:00
Add torch compile for main model
This commit is contained in:
parent
bba70eb67c
commit
75ad5e41eb
@ -152,6 +152,7 @@ class Hunyuan3DDiTPipeline:
|
||||
offload_device=torch.device('cpu'),
|
||||
dtype=torch.float16,
|
||||
use_safetensors=None,
|
||||
compile_args=None,
|
||||
**kwargs,
|
||||
):
|
||||
# load config
|
||||
@ -200,6 +201,13 @@ class Hunyuan3DDiTPipeline:
|
||||
image_processor = instantiate_from_config(config['image_processor'])
|
||||
scheduler = instantiate_from_config(config['scheduler'])
|
||||
|
||||
if compile_args is not None:
|
||||
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
|
||||
if compile_args["compile_transformer"]:
|
||||
model = torch.compile(model)
|
||||
if compile_args["compile_vae"]:
|
||||
vae = torch.compile(vae)
|
||||
|
||||
model_kwargs = dict(
|
||||
vae=vae,
|
||||
model=model,
|
||||
|
||||
53
nodes.py
53
nodes.py
@ -15,6 +15,40 @@ script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
from .utils import log, print_memory
|
||||
|
||||
class Hy3DTorchCompileSettings:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
|
||||
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
||||
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
||||
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
|
||||
"compile_transformer": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
|
||||
"compile_vae": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("HY3DCOMPILEARGS",)
|
||||
RETURN_NAMES = ("torch_compile_args",)
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "HunyuanVideoWrapper"
|
||||
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
|
||||
|
||||
def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
|
||||
|
||||
compile_args = {
|
||||
"backend": backend,
|
||||
"fullgraph": fullgraph,
|
||||
"mode": mode,
|
||||
"dynamic": dynamic,
|
||||
"dynamo_cache_size_limit": dynamo_cache_size_limit,
|
||||
"compile_transformer": compile_transformer,
|
||||
"compile_vae": compile_vae,
|
||||
}
|
||||
|
||||
return (compile_args, )
|
||||
|
||||
#region Model loading
|
||||
class Hy3DModelLoader:
|
||||
@classmethod
|
||||
@ -23,6 +57,9 @@ class Hy3DModelLoader:
|
||||
"required": {
|
||||
"model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
|
||||
},
|
||||
"optional": {
|
||||
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("HY3DMODEL",)
|
||||
@ -30,13 +67,19 @@ class Hy3DModelLoader:
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "Hunyuan3DWrapper"
|
||||
|
||||
def loadmodel(self, model):
|
||||
def loadmodel(self, model, compile_args=None):
|
||||
device = mm.get_torch_device()
|
||||
offload_device=mm.unet_offload_device()
|
||||
|
||||
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
|
||||
model_path = folder_paths.get_full_path("diffusion_models", model)
|
||||
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)
|
||||
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
|
||||
ckpt_path=model_path,
|
||||
config_path=config_path,
|
||||
use_safetensors=True,
|
||||
device=device,
|
||||
offload_device=offload_device,
|
||||
compile_args=compile_args)
|
||||
return (pipe,)
|
||||
|
||||
class DownloadAndLoadHy3DDelightModel:
|
||||
@ -455,7 +498,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
|
||||
"Hy3DDelightImage": Hy3DDelightImage,
|
||||
"Hy3DRenderMultiView": Hy3DRenderMultiView,
|
||||
"Hy3DBakeFromMultiview": Hy3DBakeFromMultiview
|
||||
"Hy3DBakeFromMultiview": Hy3DBakeFromMultiview,
|
||||
"Hy3DTorchCompileSettings": Hy3DTorchCompileSettings
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Hy3DModelLoader": "Hy3DModelLoader",
|
||||
@ -465,5 +509,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",
|
||||
"Hy3DDelightImage": "Hy3DDelightImage",
|
||||
"Hy3DRenderMultiView": "Hy3D Render MultiView",
|
||||
"Hy3DBakeFromMultiview": "Hy3D Bake From Multiview"
|
||||
"Hy3DBakeFromMultiview": "Hy3D Bake From Multiview",
|
||||
"Hy3DTorchCompileSettings": "Hy3D Torch Compile Settings"
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user