Add MochiSigmaSchedule node, better denoise formula

This commit is contained in:
kijai 2024-11-01 19:36:40 +02:00
parent ec298a1d64
commit 85c996d7b8
2 changed files with 29 additions and 3 deletions

View File

@ -268,7 +268,7 @@ class T2VSynthMochiModel:
dtype=torch.float32,
)
if in_samples is not None:
z = z * sigma_schedule[0] + in_samples.to(self.device) * sigma_schedule[-2]
z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],

View File

@ -44,6 +44,30 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
class MochiSigmaSchedule:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"num_steps": ("INT", {"default": 30, "tooltip": "Number of steps","min": 0, "max": 10000, "step": 1}),
"threshold_noise": ("FLOAT", {"default": 0.025, "min": 0.0, "max": 1.0, "step": 0.001}),
"linear_steps": ("INT", {"default": 15, "min": 1, "max": 10000, "step": 1, "tooltip": "Number of steps to scale linearly, default should be half the steps"}),
},
}
RETURN_TYPES = ("SIGMAS",)
RETURN_NAMES = ("sigmas",)
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def loadmodel(self, num_steps, threshold_noise, linear_steps=None):
sigma_schedule = linear_quadratic_schedule(num_steps, threshold_noise, linear_steps)
sigma_schedule = torch.tensor(sigma_schedule[:-1])
return (sigma_schedule, )
#region ModelLoading
class DownloadAndLoadMochiModel:
@classmethod
@ -808,7 +832,8 @@ NODE_CLASS_MAPPINGS = {
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
"MochiTorchCompileSettings": MochiTorchCompileSettings,
"MochiImageEncode": MochiImageEncode,
"MochiLatentPreview": MochiLatentPreview
"MochiLatentPreview": MochiLatentPreview,
"MochiSigmaSchedule": MochiSigmaSchedule
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@ -821,5 +846,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
"MochiTorchCompileSettings": "Mochi Torch Compile Settings",
"MochiImageEncode": "Mochi Image Encode",
"MochiLatentPreview": "Mochi Latent Preview"
"MochiLatentPreview": "Mochi Latent Preview",
"MochiSigmaSchedule": "Mochi Sigma Schedule"
}