torch.compile support
works in Windows with torch 2.5.0 and Triton from https://github.com/woct0rdho/triton-windows
This commit is contained in:
parent
36a4275b3b
commit
25eeab3c4c
@ -42,9 +42,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
SAGEATTN_IS_AVAILABLE = False
|
SAGEATTN_IS_AVAILABLE = False
|
||||||
|
|
||||||
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
|
|
||||||
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
|
|
||||||
|
|
||||||
backends = []
|
backends = []
|
||||||
if torch.cuda.get_device_properties(0).major <= 7.5:
|
if torch.cuda.get_device_properties(0).major <= 7.5:
|
||||||
backends.append(SDPBackend.MATH)
|
backends.append(SDPBackend.MATH)
|
||||||
@ -317,7 +314,6 @@ class AsymmetricAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
|
|
||||||
class AsymmetricJointBlock(nn.Module):
|
class AsymmetricJointBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -441,7 +437,6 @@ class AsymmetricJointBlock(nn.Module):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
#@torch.compile(disable=not COMPILE_FINAL_LAYER)
|
|
||||||
class FinalLayer(nn.Module):
|
class FinalLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
The final layer of DiT.
|
The final layer of DiT.
|
||||||
@ -586,7 +581,6 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
"""
|
"""
|
||||||
return self.x_embedder(x) # Convert BcTHW to BCN
|
return self.x_embedder(x) # Convert BcTHW to BCN
|
||||||
|
|
||||||
#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
|
|
||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@ -18,6 +18,6 @@ class ModulatedRMSNorm(torch.autograd.Function):
|
|||||||
|
|
||||||
return x_modulated.type_as(x)
|
return x_modulated.type_as(x)
|
||||||
|
|
||||||
|
@torch.compiler.disable()
|
||||||
def modulated_rmsnorm(x, scale, eps=1e-6):
|
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||||
return ModulatedRMSNorm.apply(x, scale, eps)
|
return ModulatedRMSNorm.apply(x, scale, eps)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -98,7 +98,8 @@ class T2VSynthMochiModel:
|
|||||||
dit_checkpoint_path: str,
|
dit_checkpoint_path: str,
|
||||||
weight_dtype: torch.dtype = torch.float8_e4m3fn,
|
weight_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
fp8_fastmode: bool = False,
|
fp8_fastmode: bool = False,
|
||||||
attention_mode: str = "sdpa"
|
attention_mode: str = "sdpa",
|
||||||
|
compile_args: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -157,8 +158,17 @@ class T2VSynthMochiModel:
|
|||||||
from ..fp8_optimization import convert_fp8_linear
|
from ..fp8_optimization import convert_fp8_linear
|
||||||
convert_fp8_linear(model, torch.bfloat16)
|
convert_fp8_linear(model, torch.bfloat16)
|
||||||
|
|
||||||
|
model = model.eval().to(self.device)
|
||||||
|
|
||||||
|
#torch.compile
|
||||||
|
if compile_args is not None:
|
||||||
|
if compile_args["compile_dit"]:
|
||||||
|
for i, block in enumerate(model.blocks):
|
||||||
|
model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
|
||||||
|
if compile_args["compile_final_layer"]:
|
||||||
|
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
|
||||||
|
|
||||||
self.dit = model
|
self.dit = model
|
||||||
self.dit.eval()
|
|
||||||
|
|
||||||
vae_stats = json.load(open(vae_stats_path))
|
vae_stats = json.load(open(vae_stats_path))
|
||||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||||
|
|||||||
48
nodes.py
48
nodes.py
@ -68,6 +68,7 @@ class DownloadAndLoadMochiModel:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
||||||
|
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class DownloadAndLoadMochiModel:
|
|||||||
CATEGORY = "MochiWrapper"
|
CATEGORY = "MochiWrapper"
|
||||||
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
|
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
|
||||||
|
|
||||||
def loadmodel(self, model, vae, precision, attention_mode, trigger=None):
|
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
|
||||||
|
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
@ -121,7 +122,8 @@ class DownloadAndLoadMochiModel:
|
|||||||
dit_checkpoint_path=model_path,
|
dit_checkpoint_path=model_path,
|
||||||
weight_dtype=dtype,
|
weight_dtype=dtype,
|
||||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||||
attention_mode=attention_mode
|
attention_mode=attention_mode,
|
||||||
|
compile_args=compile_args
|
||||||
)
|
)
|
||||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
vae = Decoder(
|
vae = Decoder(
|
||||||
@ -161,6 +163,7 @@ class MochiModelLoader:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
||||||
|
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("MOCHIMODEL",)
|
RETURN_TYPES = ("MOCHIMODEL",)
|
||||||
@ -168,7 +171,7 @@ class MochiModelLoader:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "MochiWrapper"
|
CATEGORY = "MochiWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model_name, precision, attention_mode, trigger=None):
|
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
|
||||||
|
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
@ -184,11 +187,42 @@ class MochiModelLoader:
|
|||||||
dit_checkpoint_path=model_path,
|
dit_checkpoint_path=model_path,
|
||||||
weight_dtype=dtype,
|
weight_dtype=dtype,
|
||||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||||
attention_mode=attention_mode
|
attention_mode=attention_mode,
|
||||||
|
compile_args=compile_args
|
||||||
)
|
)
|
||||||
|
|
||||||
return (model, )
|
return (model, )
|
||||||
|
|
||||||
|
class MochiTorchCompileSettings:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"backend": (["inductor","cudagraph"], {"default": "inductor"}),
|
||||||
|
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
||||||
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||||
|
"compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
|
||||||
|
"compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("MOCHICOMPILEARGS",)
|
||||||
|
RETURN_NAMES = ("torch_compile_args",)
|
||||||
|
FUNCTION = "loadmodel"
|
||||||
|
CATEGORY = "MochiWrapper"
|
||||||
|
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
|
||||||
|
|
||||||
|
def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer):
|
||||||
|
|
||||||
|
compile_args = {
|
||||||
|
"backend": backend,
|
||||||
|
"fullgraph": fullgraph,
|
||||||
|
"mode": mode,
|
||||||
|
"compile_dit": compile_dit,
|
||||||
|
"compile_final_layer": compile_final_layer,
|
||||||
|
}
|
||||||
|
|
||||||
|
return (compile_args, )
|
||||||
|
|
||||||
class MochiVAELoader:
|
class MochiVAELoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -522,7 +556,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"MochiTextEncode": MochiTextEncode,
|
"MochiTextEncode": MochiTextEncode,
|
||||||
"MochiModelLoader": MochiModelLoader,
|
"MochiModelLoader": MochiModelLoader,
|
||||||
"MochiVAELoader": MochiVAELoader,
|
"MochiVAELoader": MochiVAELoader,
|
||||||
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling
|
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
|
||||||
|
"MochiTorchCompileSettings": MochiTorchCompileSettings
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
|
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
|
||||||
@ -531,5 +566,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"MochiTextEncode": "Mochi TextEncode",
|
"MochiTextEncode": "Mochi TextEncode",
|
||||||
"MochiModelLoader": "Mochi Model Loader",
|
"MochiModelLoader": "Mochi Model Loader",
|
||||||
"MochiVAELoader": "Mochi VAE Loader",
|
"MochiVAELoader": "Mochi VAE Loader",
|
||||||
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling"
|
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
|
||||||
|
"MochiTorchCompileSettings": "Mochi Torch Compile Settings"
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user