torch.compile support

works in Windows with torch 2.5.0 and Triton from https://github.com/woct0rdho/triton-windows
This commit is contained in:
kijai 2024-10-25 18:15:30 +03:00
parent 36a4275b3b
commit 25eeab3c4c
4 changed files with 56 additions and 16 deletions

View File

@ -42,9 +42,6 @@ try:
except ImportError:
SAGEATTN_IS_AVAILABLE = False
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
backends = []
if torch.cuda.get_device_properties(0).major <= 7.5:
backends.append(SDPBackend.MATH)
@ -317,7 +314,6 @@ class AsymmetricAttention(nn.Module):
)
return x, y
#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
@ -441,7 +437,6 @@ class AsymmetricJointBlock(nn.Module):
return y
#@torch.compile(disable=not COMPILE_FINAL_LAYER)
class FinalLayer(nn.Module):
"""
The final layer of DiT.
@ -586,7 +581,6 @@ class AsymmDiTJoint(nn.Module):
"""
return self.x_embedder(x) # Convert BcTHW to BCN
#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
def prepare(
self,
x: torch.Tensor,

View File

@ -18,6 +18,6 @@ class ModulatedRMSNorm(torch.autograd.Function):
return x_modulated.type_as(x)
@torch.compiler.disable()
def modulated_rmsnorm(x, scale, eps=1e-6):
return ModulatedRMSNorm.apply(x, scale, eps)

View File

@ -1,5 +1,5 @@
import json
from typing import Dict, List
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
@ -98,7 +98,8 @@ class T2VSynthMochiModel:
dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False,
attention_mode: str = "sdpa"
attention_mode: str = "sdpa",
compile_args: Optional[Dict] = None,
):
super().__init__()
self.device = device
@ -157,8 +158,17 @@ class T2VSynthMochiModel:
from ..fp8_optimization import convert_fp8_linear
convert_fp8_linear(model, torch.bfloat16)
model = model.eval().to(self.device)
#torch.compile
if compile_args is not None:
if compile_args["compile_dit"]:
for i, block in enumerate(model.blocks):
model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
if compile_args["compile_final_layer"]:
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
self.dit = model
self.dit.eval()
vae_stats = json.load(open(vae_stats_path))
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)

View File

@ -68,6 +68,7 @@ class DownloadAndLoadMochiModel:
},
"optional": {
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
},
}
@ -77,7 +78,7 @@ class DownloadAndLoadMochiModel:
CATEGORY = "MochiWrapper"
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
def loadmodel(self, model, vae, precision, attention_mode, trigger=None):
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -121,7 +122,8 @@ class DownloadAndLoadMochiModel:
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode
attention_mode=attention_mode,
compile_args=compile_args
)
with (init_empty_weights() if is_accelerate_available else nullcontext()):
vae = Decoder(
@ -161,6 +163,7 @@ class MochiModelLoader:
},
"optional": {
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
},
}
RETURN_TYPES = ("MOCHIMODEL",)
@ -168,7 +171,7 @@ class MochiModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, precision, attention_mode, trigger=None):
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -184,11 +187,42 @@ class MochiModelLoader:
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode
attention_mode=attention_mode,
compile_args=compile_args
)
return (model, )
class MochiTorchCompileSettings:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"backend": (["inductor","cudagraph"], {"default": "inductor"}),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
"compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
},
}
RETURN_TYPES = ("MOCHICOMPILEARGS",)
RETURN_NAMES = ("torch_compile_args",)
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer):
compile_args = {
"backend": backend,
"fullgraph": fullgraph,
"mode": mode,
"compile_dit": compile_dit,
"compile_final_layer": compile_final_layer,
}
return (compile_args, )
class MochiVAELoader:
@classmethod
def INPUT_TYPES(s):
@ -522,7 +556,8 @@ NODE_CLASS_MAPPINGS = {
"MochiTextEncode": MochiTextEncode,
"MochiModelLoader": MochiModelLoader,
"MochiVAELoader": MochiVAELoader,
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
"MochiTorchCompileSettings": MochiTorchCompileSettings
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@ -531,5 +566,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiTextEncode": "Mochi TextEncode",
"MochiModelLoader": "Mochi Model Loader",
"MochiVAELoader": "Mochi VAE Loader",
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling"
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
"MochiTorchCompileSettings": "Mochi Torch Compile Settings"
}