From fb880273a00b445f0b900900a624f6a06a018d45 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 23 Oct 2024 17:04:50 +0300
Subject: [PATCH] update

---
 fp8_optimization.py              | 45 ++++++++++++++++++++++++++++++++
 mochi_preview/t2v_synth_mochi.py | 15 ++++++++---
 nodes.py                         | 23 ++++++++--------
 3 files changed, 68 insertions(+), 15 deletions(-)
 create mode 100644 fp8_optimization.py

diff --git a/fp8_optimization.py b/fp8_optimization.py
new file mode 100644
index 0000000..b01ac91
--- /dev/null
+++ b/fp8_optimization.py
@@ -0,0 +1,45 @@
+#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+
+import torch
+import torch.nn as nn
+
+def fp8_linear_forward(cls, original_dtype, input):
+    weight_dtype = cls.weight.dtype
+    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        if len(input.shape) == 3:
+            if weight_dtype == torch.float8_e4m3fn:
+                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
+            else:
+                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
+            w = cls.weight.t()
+
+            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = scale_weight
+
+            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
+            out_dtype = original_dtype
+
+            if bias is not None:
+                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            else:
+                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
+
+            if isinstance(o, tuple):
+                o = o[0]
+
+            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+        else:
+            cls.to(original_dtype)
+            out = cls.original_forward(input.to(original_dtype))
+            cls.to(original_dtype)
+            return out
+    else:
+        return cls.original_forward(input)
+
+def convert_fp8_linear(module, original_dtype):
+    setattr(module, "fp8_matmul_enabled", True)
+    for name, module in module.named_modules():
+        if isinstance(module, nn.Linear):
+            original_forward = module.forward
+            setattr(module, "original_forward", original_forward)
+            setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index f696af0..bfa1b05 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -113,14 +113,17 @@ class T2VSynthMochiModel:
     def __init__(
         self,
         *,
-        device_id: int,
+        device: torch.device,
+        offload_device: torch.device,
         vae_stats_path: str,
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
+        fp8_fastmode: bool = False,
     ):
         super().__init__()
         t = Timer()
-        self.device = torch.device(device_id)
+        self.device = device
+        self.offload_device = offload_device
 
         with t("construct_dit"):
             from .dit.joint_model.asymm_models_joint import (
@@ -162,6 +165,10 @@ class T2VSynthMochiModel:
                     param.data = param.data.to(weight_dtype)
                 else:
                     param.data = param.data.to(torch.bfloat16)
+
+            if fp8_fastmode:
+                from ..fp8_optimization import convert_fp8_linear
+                convert_fp8_linear(model, torch.bfloat16)
 
         self.dit = model
         self.dit.eval()
@@ -211,7 +218,7 @@ class T2VSynthMochiModel:
                     caption_input_ids_t5, caption_attention_mask_t5
                 ).last_hidden_state.detach().to(torch.float32)
             )
-            self.t5_enc.to("cpu")
+            self.t5_enc.to(self.offload_device)
             # Sometimes returns a tensor, othertimes a tuple, not sure why
             # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
             assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
@@ -367,7 +374,7 @@ class T2VSynthMochiModel:
                 if batch_cfg:
                     z = z[:B]
                 z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
-            self.dit.to("cpu")
+            self.dit.to(self.offload_device)
 
             samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
             print("samples: ", samples.shape, samples.dtype, samples.device)
diff --git a/nodes.py b/nodes.py
index 8fcb47a..7954a8c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -57,7 +57,7 @@ class DownloadAndLoadMochiModel:
                     ],
                     {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
                 ),
-                "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
+                "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
                     {"default": "fp8_e4m3fn", }
                 ),
             },
@@ -75,7 +75,7 @@ class DownloadAndLoadMochiModel:
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
 
-        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
 
         # Transformer model
         model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
@@ -107,10 +107,12 @@ class DownloadAndLoadMochiModel:
             )
 
         model = T2VSynthMochiModel(
-            device_id=0,
+            device=device,
+            offload_device=offload_device,
             vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
+            fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
         )
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(
@@ -241,14 +243,13 @@ class MochiDecode:
                 "vae": ("MOCHIVAE",),
                 "samples": ("LATENT", ),
                 "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
-            },
-            "optional": {
+                "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
+                "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
                 "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
                 "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
-                "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
-            }
+            },
         }
 
     RETURN_TYPES = ("IMAGE",)
@@ -256,7 +257,8 @@ class MochiDecode:
     FUNCTION = "decode"
     CATEGORY = "MochiWrapper"
 
-    def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
+    def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         samples = samples["samples"]
@@ -279,8 +281,6 @@ class MochiDecode:
         self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
         self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
 
-        #7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
-        self.num_latent_frames_batch_size = 6
         self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
         self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
 
@@ -300,10 +300,10 @@ class MochiDecode:
             blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
             row_limit_height = self.tile_sample_min_height - blend_extent_height
             row_limit_width = self.tile_sample_min_width - blend_extent_width
-            frame_batch_size = self.num_latent_frames_batch_size
 
             # Split z into overlapping tiles and decode them separately.
             # The tiles have an overlap to avoid seams between tiles.
+            comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
             rows = []
             for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
                 row = []
@@ -324,6 +324,7 @@ class MochiDecode:
                         time.append(tile)
                     row.append(torch.cat(time, dim=2))
                 rows.append(row)
+                comfy_pbar.update(1)
 
             result_rows = []
             for i, row in enumerate(tqdm(rows, desc="Blending rows")):
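
Note (not part of the patch): below is a minimal usage sketch of the fp8 fast mode this commit adds. It assumes fp8_optimization.py is importable from the working directory, a CUDA GPU with FP8 support (compute capability 8.9 or newer), and a PyTorch build whose private torch._scaled_mm accepts the keyword arguments used in fp8_linear_forward; the toy model and tensor shapes are hypothetical and chosen only to exercise the patched forward.

# Hypothetical usage sketch, not part of the commit: exercises convert_fp8_linear
# on a toy model. Assumes fp8_optimization.py is on the import path and that the
# GPU/PyTorch combination supports torch._scaled_mm.
import torch
import torch.nn as nn
from fp8_optimization import convert_fp8_linear

model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).cuda()

# Store only the Linear weights in fp8, mirroring what t2v_synth_mochi.py does
# for most DiT parameters; biases and everything else go to bf16.
for name, param in model.named_parameters():
    target = torch.float8_e4m3fn if name.endswith("weight") else torch.bfloat16
    param.data = param.data.to(target)

# Monkey-patch each nn.Linear.forward with the fp8 fast path.
convert_fp8_linear(model, torch.bfloat16)

# fp8_linear_forward only takes the _scaled_mm path for 3-D inputs (B, T, C).
x = torch.randn(1, 16, 64, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
print(y.shape, y.dtype)  # expected: torch.Size([1, 16, 64]) torch.bfloat16

In the wrapper itself this path is only taken when the loader node's precision is set to "fp8_e4m3fn_fast", which passes fp8_fastmode=True into T2VSynthMochiModel and triggers the same convert_fp8_linear(model, torch.bfloat16) call on the DiT.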