Allow setting dynamo_cache_size_limit
parent 4ef7df00c9
commit d3287d61b7
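Adds a dynamo_cache_size_limit INT input to the MochiTorchCompileSettings node, threads it through compile_args, and applies it via torch._dynamo.config.cache_size_limit before the transformer blocks are compiled. Also marks the RMSNorm variants with @torch.compiler.disable() so they are kept out of the compiled graphs.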
@@ -165,6 +165,7 @@ class AsymmetricAttention(nn.Module):
     raise ImportError("Flash RMSNorm not available.")
 elif rms_norm_func == "apex":
     from apex.normalization import FusedRMSNorm as ApexRMSNorm
+    @torch.compiler.disable()
     class RMSNorm(ApexRMSNorm):
         pass
 else:
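The decorator added here is the same one applied to the pure-PyTorch RMSNorm further down, so a minimal sketch of its effect (torch 2.x; function names here are illustrative): torch.compiler.disable() marks a callable that Dynamo must not trace, so a compiled caller graph-breaks around it and runs it eagerly.

import torch

@torch.compiler.disable()
def fused_norm_stub(x):
    # Dynamo skips this body; it always runs eagerly, even when the
    # caller has been compiled.
    return x / (x.norm() + 1e-5)

@torch.compile
def forward(x):
    # The rest of this function is still compiled; Dynamo inserts a
    # graph break around the disabled call.
    return fused_norm_stub(x) + 1.0

print(forward(torch.randn(4)))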
@@ -237,7 +238,6 @@ class AsymmetricAttention(nn.Module):
             skip_reshape=True
         )
         return out
-
     def run_attention(
         self,
         q,
@@ -140,7 +140,7 @@ class PatchEmbed(nn.Module):
         x = self.norm(x)
         return x
 
-
+@torch.compiler.disable()
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None):
         super().__init__()
@@ -34,6 +34,7 @@ except:
 
 import torch
 import torch.utils.data
+import torch._dynamo
 
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
@@ -161,6 +162,8 @@ class T2VSynthMochiModel:
 
         #torch.compile
         if compile_args is not None:
+            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+            log.info(f"Set dynamo cache size limit to {torch._dynamo.config.cache_size_limit}")
             if compile_args["compile_dit"]:
                 for i, block in enumerate(model.blocks):
                     model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
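For context, a minimal sketch of what this limit controls (recent torch 2.x; the built-in default has historically been small, around 8, which is why shape-varying workloads benefit from raising it): with dynamic=False, every new input shape fails the existing guards and triggers a recompile, and once the cached specializations for a code object exceed cache_size_limit, Dynamo stops compiling it and falls back to eager.

import torch
import torch._dynamo

# One cache entry is stored per guard-set (e.g. per input shape when
# dynamic=False); past the limit, Dynamo gives up and runs eagerly.
torch._dynamo.config.cache_size_limit = 64

@torch.compile(dynamic=False)
def scale(x):
    return x * 2.0

# Each distinct frame count below compiles one more specialization;
# a video model that varies sequence length can exhaust a small cache fast.
for frames in (16, 24, 32, 48):
    scale(torch.randn(frames, 8))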
nodes.py (4 changed lines)
@@ -247,6 +247,7 @@ class MochiTorchCompileSettings:
                 "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
                 "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
             },
         }
     RETURN_TYPES = ("MOCHICOMPILEARGS",)
@@ -255,7 +256,7 @@ class MochiTorchCompileSettings:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
 
-    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):
+    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic, dynamo_cache_size_limit):
 
         compile_args = {
             "backend": backend,
@@ -264,6 +265,7 @@ class MochiTorchCompileSettings:
             "compile_dit": compile_dit,
             "compile_final_layer": compile_final_layer,
             "dynamic": dynamic,
+            "dynamo_cache_size_limit": dynamo_cache_size_limit,
         }
 
         return (compile_args, )
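Finally, a stripped-down sketch of the ComfyUI node contract the nodes.py changes rely on (class name and return payload here are illustrative): every key declared under "required" in INPUT_TYPES is delivered as a same-named keyword argument to the method named by FUNCTION, which is why loadmodel must grow a dynamo_cache_size_limit parameter to match the new input.

class ExampleCompileSettings:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                # An INT widget; ComfyUI renders default/min/max/step as a number box.
                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1}),
            },
        }

    RETURN_TYPES = ("MOCHICOMPILEARGS",)
    FUNCTION = "loadmodel"
    CATEGORY = "MochiWrapper"

    def loadmodel(self, dynamo_cache_size_limit):
        # The dict travels over the MOCHICOMPILEARGS link to the model loader,
        # which reads it before compiling the transformer blocks.
        return ({"dynamo_cache_size_limit": dynamo_cache_size_limit},)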