torch.compile support
Works on Windows with torch 2.5.0 and Triton from https://github.com/woct0rdho/triton-windows
parent 36a4275b3b
commit 25eeab3c4c
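For orientation before the diff: the commit threads an optional compile_args dict from a new MochiTorchCompileSettings node through the model loaders into T2VSynthMochiModel, which wraps each DiT block (and optionally the final layer) with torch.compile. A minimal, self-contained sketch of that per-block pattern follows; TinyBlock and TinyModel are illustrative stand-ins, not classes from this repo:

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Stand-in for a transformer block (the real model uses AsymmetricJointBlock).
    def __init__(self, dim=64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))

class TinyModel(nn.Module):
    def __init__(self, depth=4, dim=64):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock(dim) for _ in range(depth))
        self.final_layer = nn.Linear(dim, dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.final_layer(x)

# Same shape as the compile_args dict the new settings node emits.
compile_args = {"backend": "inductor", "fullgraph": False,
                "compile_dit": True, "compile_final_layer": True}

model = TinyModel()
if compile_args["compile_dit"]:
    # Compile block-by-block: many small graphs instead of one large one,
    # which keeps warm-up and any recompilation localized to a single block.
    for i, block in enumerate(model.blocks):
        model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"],
                                        dynamic=False, backend=compile_args["backend"])
if compile_args["compile_final_layer"]:
    model.final_layer = torch.compile(model.final_layer,
                                      fullgraph=compile_args["fullgraph"],
                                      dynamic=False, backend=compile_args["backend"])

out = model(torch.randn(2, 64))  # first call triggers compilation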
@@ -42,9 +42,6 @@ try:
 except ImportError:
     SAGEATTN_IS_AVAILABLE = False
 
-COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
-COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
-
 backends = []
 if torch.cuda.get_device_properties(0).major <= 7.5:
     backends.append(SDPBackend.MATH)
@@ -317,7 +314,6 @@ class AsymmetricAttention(nn.Module):
         )
         return x, y
 
-#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
 class AsymmetricJointBlock(nn.Module):
     def __init__(
         self,
@@ -441,7 +437,6 @@ class AsymmetricJointBlock(nn.Module):
         return y
 
 
-#@torch.compile(disable=not COMPILE_FINAL_LAYER)
 class FinalLayer(nn.Module):
     """
     The final layer of DiT.
@@ -586,7 +581,6 @@ class AsymmDiTJoint(nn.Module):
         """
         return self.x_embedder(x)  # Convert BcTHW to BCN
 
-    #@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
     def prepare(
         self,
         x: torch.Tensor,
@@ -18,6 +18,6 @@ class ModulatedRMSNorm(torch.autograd.Function):
 
         return x_modulated.type_as(x)
 
-
+@torch.compiler.disable()
 def modulated_rmsnorm(x, scale, eps=1e-6):
     return ModulatedRMSNorm.apply(x, scale, eps)
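The @torch.compiler.disable() added above keeps Dynamo from tracing into the custom autograd.Function, which compilation often cannot handle cleanly; the call simply falls back to eager, at the cost of a graph break at that site. A toy sketch of the same pattern (Square is illustrative, not from this repo):

import torch

class Square(torch.autograd.Function):
    # Stand-in for ModulatedRMSNorm: a custom forward/backward pair.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

@torch.compiler.disable()
def square(x):
    # Always runs eagerly, even when the surrounding module is torch.compile'd.
    return Square.apply(x)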
@@ -1,5 +1,5 @@
 import json
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -98,7 +98,8 @@ class T2VSynthMochiModel:
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
         fp8_fastmode: bool = False,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        compile_args: Optional[Dict] = None,
     ):
         super().__init__()
         self.device = device
@@ -157,8 +158,17 @@ class T2VSynthMochiModel:
             from ..fp8_optimization import convert_fp8_linear
             convert_fp8_linear(model, torch.bfloat16)
 
         model = model.eval().to(self.device)
+
+        #torch.compile
+        if compile_args is not None:
+            if compile_args["compile_dit"]:
+                for i, block in enumerate(model.blocks):
+                    model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
+            if compile_args["compile_final_layer"]:
+                model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
+
         self.dit = model
         self.dit.eval()
 
         vae_stats = json.load(open(vae_stats_path))
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
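Note that the inductor backend on CUDA needs a working Triton install (hence the Windows build linked in the commit message). A defensive variant of the hook above could check for Triton before compiling; a sketch under that assumption, not part of this commit:

import importlib.util

def triton_available() -> bool:
    # torch.compile's inductor backend requires Triton to emit CUDA kernels.
    return importlib.util.find_spec("triton") is not None

# if compile_args is not None and triton_available():
#     ...wrap blocks with torch.compile as in the hunk above...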
nodes.py
@@ -68,6 +68,7 @@ class DownloadAndLoadMochiModel:
             },
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
+                "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
             },
         }
 
@@ -77,7 +78,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
 
-    def loadmodel(self, model, vae, precision, attention_mode, trigger=None):
+    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -121,7 +122,8 @@ class DownloadAndLoadMochiModel:
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
-            attention_mode=attention_mode
+            attention_mode=attention_mode,
+            compile_args=compile_args
         )
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(
@@ -161,6 +163,7 @@ class MochiModelLoader:
             },
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
+                "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
             },
         }
     RETURN_TYPES = ("MOCHIMODEL",)
@@ -168,7 +171,7 @@ class MochiModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"
 
-    def loadmodel(self, model_name, precision, attention_mode, trigger=None):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -184,11 +187,42 @@ class MochiModelLoader:
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
-            attention_mode=attention_mode
+            attention_mode=attention_mode,
+            compile_args=compile_args
         )
 
         return (model, )
 
+class MochiTorchCompileSettings:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "backend": (["inductor", "cudagraphs"], {"default": "inductor"}),
+                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+                "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
+                "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
+            },
+        }
+    RETURN_TYPES = ("MOCHICOMPILEARGS",)
+    RETURN_NAMES = ("torch_compile_args",)
+    FUNCTION = "loadmodel"
+    CATEGORY = "MochiWrapper"
+    DESCRIPTION = "torch.compile settings; when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton; torch 2.5.0 is recommended."
+
+    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer):
+
+        compile_args = {
+            "backend": backend,
+            "fullgraph": fullgraph,
+            "mode": mode,
+            "compile_dit": compile_dit,
+            "compile_final_layer": compile_final_layer,
+        }
+
+        return (compile_args, )
+
 class MochiVAELoader:
     @classmethod
     def INPUT_TYPES(s):
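The settings node just packages its widget values into a plain dict and returns it as a MOCHICOMPILEARGS value; in the graph, that output is wired into the loaders' optional compile_args input. A hypothetical direct call outside ComfyUI (the import path is illustrative):

from nodes import MochiTorchCompileSettings  # illustrative import path

(compile_args,) = MochiTorchCompileSettings().loadmodel(
    backend="inductor", fullgraph=False, mode="default",
    compile_dit=True, compile_final_layer=True,
)
# MochiModelLoader().loadmodel(model_name, precision, attention_mode,
#                              compile_args=compile_args)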
@@ -522,7 +556,8 @@ NODE_CLASS_MAPPINGS = {
     "MochiTextEncode": MochiTextEncode,
     "MochiModelLoader": MochiModelLoader,
     "MochiVAELoader": MochiVAELoader,
-    "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling
+    "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
+    "MochiTorchCompileSettings": MochiTorchCompileSettings
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@@ -531,5 +566,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "MochiTextEncode": "Mochi TextEncode",
     "MochiModelLoader": "Mochi Model Loader",
     "MochiVAELoader": "Mochi VAE Loader",
-    "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling"
+    "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
+    "MochiTorchCompileSettings": "Mochi Torch Compile Settings"
     }