diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 8ea61bb..a1c6fb7 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -268,7 +268,7 @@ class T2VSynthMochiModel: dtype=torch.float32, ) if in_samples is not None: - z = z * sigma_schedule[0] + in_samples.to(self.device) * sigma_schedule[-2] + z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device) sample = { "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)], diff --git a/nodes.py b/nodes.py index 69c2722..92268ce 100644 --- a/nodes.py +++ b/nodes.py @@ -44,6 +44,30 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): sigma_schedule = [1.0 - x for x in sigma_schedule] return sigma_schedule +class MochiSigmaSchedule: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "num_steps": ("INT", {"default": 30, "tooltip": "Number of steps","min": 0, "max": 10000, "step": 1}), + "threshold_noise": ("FLOAT", {"default": 0.025, "min": 0.0, "max": 1.0, "step": 0.001}), + "linear_steps": ("INT", {"default": 15, "min": 1, "max": 10000, "step": 1, "tooltip": "Number of steps to scale linearly, default should be half the steps"}), + + }, + } + RETURN_TYPES = ("SIGMAS",) + RETURN_NAMES = ("sigmas",) + FUNCTION = "loadmodel" + CATEGORY = "MochiWrapper" + DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" + + def loadmodel(self, num_steps, threshold_noise, linear_steps=None): + + sigma_schedule = linear_quadratic_schedule(num_steps, threshold_noise, linear_steps) + sigma_schedule = torch.tensor(sigma_schedule[:-1]) + + return (sigma_schedule, ) + #region ModelLoading class DownloadAndLoadMochiModel: @classmethod @@ -808,7 +832,8 @@ NODE_CLASS_MAPPINGS = { "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling, "MochiTorchCompileSettings": MochiTorchCompileSettings, "MochiImageEncode": MochiImageEncode, - "MochiLatentPreview": MochiLatentPreview + "MochiLatentPreview": MochiLatentPreview, + "MochiSigmaSchedule": MochiSigmaSchedule } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadMochiModel": "(Down)load Mochi Model", @@ -821,5 +846,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling", "MochiTorchCompileSettings": "Mochi Torch Compile Settings", "MochiImageEncode": "Mochi Image Encode", - "MochiLatentPreview": "Mochi Latent Preview" + "MochiLatentPreview": "Mochi Latent Preview", + "MochiSigmaSchedule": "Mochi Sigma Schedule" }