Add TorchCompileModelFluxAdvancedV2

Uses new ComfyUI core functions that also support LoRAs without requiring additional patches
This commit is contained in:
kijai 2025-05-21 17:17:41 +03:00
parent 44565e9bff
commit 16f60e53e5
2 changed files with 59 additions and 0 deletions

View File

@ -174,6 +174,7 @@ NODE_CONFIG = {
"CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
"DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
"TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
"TorchCompileModelFluxAdvancedV2": {"class": TorchCompileModelFluxAdvancedV2, "name": "TorchCompileModelFluxAdvancedV2"},
"TorchCompileModelHyVideo": {"class": TorchCompileModelHyVideo, "name": "TorchCompileModelHyVideo"},
"TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
"TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},

View File

@ -22,6 +22,7 @@ class BaseLoaderKJ:
original_linear = None
cublas_patched = False
@torch.compiler.disable()
def _patch_modules(self, patch_cublaslinear, sage_attention):
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
@ -146,8 +147,10 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
def patch(self, model, sage_attention):
model_clone = model.clone()
@torch.compiler.disable()
def patch_attention_enable(model):
self._patch_modules(False, sage_attention)
@torch.compiler.disable()
def patch_attention_disable(model):
self._patch_modules(False, "disabled")
@ -574,6 +577,61 @@ class TorchCompileModelFluxAdvanced:
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
class TorchCompileModelFluxAdvancedV2:
    """Compile the transformer blocks of a Flux diffusion model with torch.compile.

    Uses ComfyUI core's ``set_torch_compile_wrapper`` so compiled blocks keep
    working with LoRAs without additional patches.
    """

    def __init__(self):
        # Parity with the other TorchCompile* nodes in this file; not read here.
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "backend": (["inductor", "cudagraphs"],),
                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                "double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
                "single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
            },
            "optional": {
                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True

    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit=64):
        """Return a clone of *model* with its selected Flux blocks wrapped for compilation.

        Args:
            model: ComfyUI MODEL patcher to clone and compile.
            backend: torch.compile backend ("inductor" or "cudagraphs").
            mode: torch.compile mode string.
            fullgraph: require a single graph with no breaks.
            single_blocks / double_blocks: which Flux block groups to compile.
            dynamic: enable dynamic-shape compilation.
            dynamo_cache_size_limit: value for torch._dynamo.config.cache_size_limit.
                Defaults to 64 to match the INPUT_TYPES "optional" declaration.

        Raises:
            RuntimeError: if set_torch_compile_wrapper fails; the original
                exception is chained as the cause.
        """
        from comfy_api.torch_helpers import set_torch_compile_wrapper

        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit

        # NOTE(review): upstream wrapper is presumed to accept the block
        # modules themselves as keys here — confirm against comfy_api.
        compile_key_list = []
        if double_blocks:
            compile_key_list.extend(diffusion_model.double_blocks)
        if single_blocks:
            compile_key_list.extend(diffusion_model.single_blocks)

        # Keep the try body narrow so unrelated errors (e.g. a model without
        # double_blocks) surface with their real type, and chain the cause
        # instead of swallowing it with a bare except.
        try:
            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e
        return (m, )
class TorchCompileModelHyVideo:
def __init__(self):
self._compiled = False