Allow setting dynamo_cache_size_limit

This commit is contained in:
kijai 2024-11-06 23:38:07 +02:00
parent 4ef7df00c9
commit d3287d61b7
4 changed files with 8 additions and 3 deletions

View File

@ -165,6 +165,7 @@ class AsymmetricAttention(nn.Module):
raise ImportError("Flash RMSNorm not available.") raise ImportError("Flash RMSNorm not available.")
elif rms_norm_func == "apex": elif rms_norm_func == "apex":
from apex.normalization import FusedRMSNorm as ApexRMSNorm from apex.normalization import FusedRMSNorm as ApexRMSNorm
@torch.compiler.disable()
class RMSNorm(ApexRMSNorm): class RMSNorm(ApexRMSNorm):
pass pass
else: else:
@ -237,7 +238,6 @@ class AsymmetricAttention(nn.Module):
skip_reshape=True skip_reshape=True
) )
return out return out
def run_attention( def run_attention(
self, self,
q, q,

View File

@ -140,7 +140,7 @@ class PatchEmbed(nn.Module):
x = self.norm(x) x = self.norm(x)
return x return x
@torch.compiler.disable()
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None): def __init__(self, hidden_size, eps=1e-5, device=None):
super().__init__() super().__init__()

View File

@ -34,6 +34,7 @@ except:
import torch import torch
import torch.utils.data import torch.utils.data
import torch._dynamo
from tqdm import tqdm from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file from comfy.utils import ProgressBar, load_torch_file
@ -161,6 +162,8 @@ class T2VSynthMochiModel:
#torch.compile #torch.compile
if compile_args is not None: if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
log.info(f"Set dynamo cache size limit to {torch._dynamo.config.cache_size_limit}")
if compile_args["compile_dit"]: if compile_args["compile_dit"]:
for i, block in enumerate(model.blocks): for i, block in enumerate(model.blocks):
model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"]) model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])

View File

@ -247,6 +247,7 @@ class MochiTorchCompileSettings:
"compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}), "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
"compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}), "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
}, },
} }
RETURN_TYPES = ("MOCHICOMPILEARGS",) RETURN_TYPES = ("MOCHICOMPILEARGS",)
@ -255,7 +256,7 @@ class MochiTorchCompileSettings:
CATEGORY = "MochiWrapper" CATEGORY = "MochiWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic): def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic, dynamo_cache_size_limit):
compile_args = { compile_args = {
"backend": backend, "backend": backend,
@ -264,6 +265,7 @@ class MochiTorchCompileSettings:
"compile_dit": compile_dit, "compile_dit": compile_dit,
"compile_final_layer": compile_final_layer, "compile_final_layer": compile_final_layer,
"dynamic": dynamic, "dynamic": dynamic,
"dynamo_cache_size_limit": dynamo_cache_size_limit,
} }
return (compile_args, ) return (compile_args, )