mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-13 14:54:35 +08:00
Add torch compile for main model
This commit is contained in:
parent
bba70eb67c
commit
75ad5e41eb
@ -152,6 +152,7 @@ class Hunyuan3DDiTPipeline:
|
|||||||
offload_device=torch.device('cpu'),
|
offload_device=torch.device('cpu'),
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
use_safetensors=None,
|
use_safetensors=None,
|
||||||
|
compile_args=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# load config
|
# load config
|
||||||
@ -200,6 +201,13 @@ class Hunyuan3DDiTPipeline:
|
|||||||
image_processor = instantiate_from_config(config['image_processor'])
|
image_processor = instantiate_from_config(config['image_processor'])
|
||||||
scheduler = instantiate_from_config(config['scheduler'])
|
scheduler = instantiate_from_config(config['scheduler'])
|
||||||
|
|
||||||
|
if compile_args is not None:
|
||||||
|
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
|
||||||
|
if compile_args["compile_transformer"]:
|
||||||
|
model = torch.compile(model)
|
||||||
|
if compile_args["compile_vae"]:
|
||||||
|
vae = torch.compile(vae)
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
53
nodes.py
53
nodes.py
@ -15,6 +15,40 @@ script_directory = os.path.dirname(os.path.abspath(__file__))
|
|||||||
|
|
||||||
from .utils import log, print_memory
|
from .utils import log, print_memory
|
||||||
|
|
||||||
|
class Hy3DTorchCompileSettings:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
|
||||||
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
||||||
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||||
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
||||||
|
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
|
||||||
|
"compile_transformer": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
|
||||||
|
"compile_vae": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("HY3DCOMPILEARGS",)
|
||||||
|
RETURN_NAMES = ("torch_compile_args",)
|
||||||
|
FUNCTION = "loadmodel"
|
||||||
|
CATEGORY = "HunyuanVideoWrapper"
|
||||||
|
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
|
||||||
|
|
||||||
|
def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
|
||||||
|
|
||||||
|
compile_args = {
|
||||||
|
"backend": backend,
|
||||||
|
"fullgraph": fullgraph,
|
||||||
|
"mode": mode,
|
||||||
|
"dynamic": dynamic,
|
||||||
|
"dynamo_cache_size_limit": dynamo_cache_size_limit,
|
||||||
|
"compile_transformer": compile_transformer,
|
||||||
|
"compile_vae": compile_vae,
|
||||||
|
}
|
||||||
|
|
||||||
|
return (compile_args, )
|
||||||
|
|
||||||
#region Model loading
|
#region Model loading
|
||||||
class Hy3DModelLoader:
|
class Hy3DModelLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -23,6 +57,9 @@ class Hy3DModelLoader:
|
|||||||
"required": {
|
"required": {
|
||||||
"model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
|
"model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
|
||||||
},
|
},
|
||||||
|
"optional": {
|
||||||
|
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("HY3DMODEL",)
|
RETURN_TYPES = ("HY3DMODEL",)
|
||||||
@ -30,13 +67,19 @@ class Hy3DModelLoader:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "Hunyuan3DWrapper"
|
CATEGORY = "Hunyuan3DWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model):
|
def loadmodel(self, model, compile_args=None):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device=mm.unet_offload_device()
|
offload_device=mm.unet_offload_device()
|
||||||
|
|
||||||
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
|
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
|
||||||
model_path = folder_paths.get_full_path("diffusion_models", model)
|
model_path = folder_paths.get_full_path("diffusion_models", model)
|
||||||
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)
|
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
|
||||||
|
ckpt_path=model_path,
|
||||||
|
config_path=config_path,
|
||||||
|
use_safetensors=True,
|
||||||
|
device=device,
|
||||||
|
offload_device=offload_device,
|
||||||
|
compile_args=compile_args)
|
||||||
return (pipe,)
|
return (pipe,)
|
||||||
|
|
||||||
class DownloadAndLoadHy3DDelightModel:
|
class DownloadAndLoadHy3DDelightModel:
|
||||||
@ -455,7 +498,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
|
"DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
|
||||||
"Hy3DDelightImage": Hy3DDelightImage,
|
"Hy3DDelightImage": Hy3DDelightImage,
|
||||||
"Hy3DRenderMultiView": Hy3DRenderMultiView,
|
"Hy3DRenderMultiView": Hy3DRenderMultiView,
|
||||||
"Hy3DBakeFromMultiview": Hy3DBakeFromMultiview
|
"Hy3DBakeFromMultiview": Hy3DBakeFromMultiview,
|
||||||
|
"Hy3DTorchCompileSettings": Hy3DTorchCompileSettings
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"Hy3DModelLoader": "Hy3DModelLoader",
|
"Hy3DModelLoader": "Hy3DModelLoader",
|
||||||
@ -465,5 +509,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",
|
"DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",
|
||||||
"Hy3DDelightImage": "Hy3DDelightImage",
|
"Hy3DDelightImage": "Hy3DDelightImage",
|
||||||
"Hy3DRenderMultiView": "Hy3D Render MultiView",
|
"Hy3DRenderMultiView": "Hy3D Render MultiView",
|
||||||
"Hy3DBakeFromMultiview": "Hy3D Bake From Multiview"
|
"Hy3DBakeFromMultiview": "Hy3D Bake From Multiview",
|
||||||
|
"Hy3DTorchCompileSettings": "Hy3D Torch Compile Settings"
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user