From 25eeab3c4cb5f92f4fdfd0dc05ce073542bbdad0 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 25 Oct 2024 18:15:30 +0300 Subject: [PATCH] torch.compile support works in Windows with torch 2.5.0 and Triton from https://github.com/woct0rdho/triton-windows --- .../dit/joint_model/asymm_models_joint.py | 6 --- mochi_preview/dit/joint_model/mod_rmsnorm.py | 2 +- mochi_preview/t2v_synth_mochi.py | 16 +++++-- nodes.py | 48 ++++++++++++++++--- 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index bf10a00..1e4c207 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -42,9 +42,6 @@ try: except ImportError: SAGEATTN_IS_AVAILABLE = False -COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1" -COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1" - backends = [] if torch.cuda.get_device_properties(0).major <= 7.5: backends.append(SDPBackend.MATH) @@ -317,7 +314,6 @@ class AsymmetricAttention(nn.Module): ) return x, y -#@torch.compile(disable=not COMPILE_MMDIT_BLOCK) class AsymmetricJointBlock(nn.Module): def __init__( self, @@ -441,7 +437,6 @@ class AsymmetricJointBlock(nn.Module): return y -#@torch.compile(disable=not COMPILE_FINAL_LAYER) class FinalLayer(nn.Module): """ The final layer of DiT. 
@@ -586,7 +581,6 @@ class AsymmDiTJoint(nn.Module): """ return self.x_embedder(x) # Convert BcTHW to BCN - #@torch.compile(disable=not COMPILE_MMDIT_BLOCK) def prepare( self, x: torch.Tensor, diff --git a/mochi_preview/dit/joint_model/mod_rmsnorm.py b/mochi_preview/dit/joint_model/mod_rmsnorm.py index ffbb4c8..b3e317c 100644 --- a/mochi_preview/dit/joint_model/mod_rmsnorm.py +++ b/mochi_preview/dit/joint_model/mod_rmsnorm.py @@ -18,6 +18,6 @@ class ModulatedRMSNorm(torch.autograd.Function): return x_modulated.type_as(x) - +@torch.compiler.disable() def modulated_rmsnorm(x, scale, eps=1e-6): return ModulatedRMSNorm.apply(x, scale, eps) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index cf957d3..ab92689 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List +from typing import Dict, List, Optional import torch import torch.nn.functional as F @@ -98,7 +98,8 @@ class T2VSynthMochiModel: dit_checkpoint_path: str, weight_dtype: torch.dtype = torch.float8_e4m3fn, fp8_fastmode: bool = False, - attention_mode: str = "sdpa" + attention_mode: str = "sdpa", + compile_args: Optional[Dict] = None, ): super().__init__() self.device = device @@ -157,8 +158,17 @@ class T2VSynthMochiModel: from ..fp8_optimization import convert_fp8_linear convert_fp8_linear(model, torch.bfloat16) + model = model.eval().to(self.device) + + #torch.compile + if compile_args is not None: + if compile_args["compile_dit"]: + for i, block in enumerate(model.blocks): + model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) + if compile_args["compile_final_layer"]: + model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) + self.dit = model - self.dit.eval() vae_stats = json.load(open(vae_stats_path)) self.vae_mean = 
torch.Tensor(vae_stats["mean"]).to(self.device) diff --git a/nodes.py b/nodes.py index 64ad2b1..da7a149 100644 --- a/nodes.py +++ b/nodes.py @@ -68,6 +68,7 @@ class DownloadAndLoadMochiModel: }, "optional": { "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}), + "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}), }, } @@ -77,7 +78,7 @@ class DownloadAndLoadMochiModel: CATEGORY = "MochiWrapper" DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface" - def loadmodel(self, model, vae, precision, attention_mode, trigger=None): + def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -121,7 +122,8 @@ class DownloadAndLoadMochiModel: dit_checkpoint_path=model_path, weight_dtype=dtype, fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False, - attention_mode=attention_mode + attention_mode=attention_mode, + compile_args=compile_args ) with (init_empty_weights() if is_accelerate_available else nullcontext()): vae = Decoder( @@ -161,6 +163,7 @@ class MochiModelLoader: }, "optional": { "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}), + "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}), }, } RETURN_TYPES = ("MOCHIMODEL",) @@ -168,7 +171,7 @@ class MochiModelLoader: FUNCTION = "loadmodel" CATEGORY = "MochiWrapper" - def loadmodel(self, model_name, precision, attention_mode, trigger=None): + def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -184,11 +187,42 @@ class MochiModelLoader: dit_checkpoint_path=model_path, weight_dtype=dtype, fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False, - attention_mode=attention_mode + attention_mode=attention_mode, + 
compile_args=compile_args ) return (model, ) +class MochiTorchCompileSettings: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "backend": (["inductor","cudagraphs"], {"default": "inductor"}), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}), + "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}), + }, + } + RETURN_TYPES = ("MOCHICOMPILEARGS",) + RETURN_NAMES = ("torch_compile_args",) + FUNCTION = "loadmodel" + CATEGORY = "MochiWrapper" + DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" + + def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer): + + compile_args = { + "backend": backend, + "fullgraph": fullgraph, + "mode": mode, + "compile_dit": compile_dit, + "compile_final_layer": compile_final_layer, + } + + return (compile_args, ) + class MochiVAELoader: @classmethod def INPUT_TYPES(s): @@ -522,7 +556,8 @@ NODE_CLASS_MAPPINGS = { "MochiTextEncode": MochiTextEncode, "MochiModelLoader": MochiModelLoader, "MochiVAELoader": MochiVAELoader, - "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling + "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling, + "MochiTorchCompileSettings": MochiTorchCompileSettings } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadMochiModel": "(Down)load Mochi Model", @@ -531,5 +566,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MochiTextEncode": "Mochi TextEncode", "MochiModelLoader": "Mochi Model Loader", "MochiVAELoader": "Mochi VAE Loader", - "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling" + "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling", + "MochiTorchCompileSettings":
"Mochi Torch Compile Settings" }