This commit is contained in:
kijai 2024-10-23 17:04:50 +03:00
parent db87f8e608
commit fb880273a0
3 changed files with 68 additions and 15 deletions

45
fp8_optimization.py Normal file
View File

@ -0,0 +1,45 @@
#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
import torch
import torch.nn as nn
def fp8_linear_forward(cls, original_dtype, input):
weight_dtype = cls.weight.dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if len(input.shape) == 3:
if weight_dtype == torch.float8_e4m3fn:
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
else:
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
w = cls.weight.t()
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
scale_input = scale_weight
bias = cls.bias.to(original_dtype) if cls.bias is not None else None
out_dtype = original_dtype
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
else:
cls.to(original_dtype)
out = cls.original_forward(input.to(original_dtype))
cls.to(original_dtype)
return out
else:
return cls.original_forward(input)
def convert_fp8_linear(module, original_dtype):
setattr(module, "fp8_matmul_enabled", True)
for name, module in module.named_modules():
if isinstance(module, nn.Linear):
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -113,14 +113,17 @@ class T2VSynthMochiModel:
def __init__( def __init__(
self, self,
*, *,
device_id: int, device: torch.device,
offload_device: torch.device,
vae_stats_path: str, vae_stats_path: str,
dit_checkpoint_path: str, dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn, weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False,
): ):
super().__init__() super().__init__()
t = Timer() t = Timer()
self.device = torch.device(device_id) self.device = device
self.offload_device = offload_device
with t("construct_dit"): with t("construct_dit"):
from .dit.joint_model.asymm_models_joint import ( from .dit.joint_model.asymm_models_joint import (
@ -162,6 +165,10 @@ class T2VSynthMochiModel:
param.data = param.data.to(weight_dtype) param.data = param.data.to(weight_dtype)
else: else:
param.data = param.data.to(torch.bfloat16) param.data = param.data.to(torch.bfloat16)
if fp8_fastmode:
from ..fp8_optimization import convert_fp8_linear
convert_fp8_linear(model, torch.bfloat16)
self.dit = model self.dit = model
self.dit.eval() self.dit.eval()
@ -211,7 +218,7 @@ class T2VSynthMochiModel:
caption_input_ids_t5, caption_attention_mask_t5 caption_input_ids_t5, caption_attention_mask_t5
).last_hidden_state.detach().to(torch.float32) ).last_hidden_state.detach().to(torch.float32)
) )
self.t5_enc.to("cpu") self.t5_enc.to(self.offload_device)
# Sometimes returns a tensor, othertimes a tuple, not sure why # Sometimes returns a tensor, othertimes a tuple, not sure why
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3 # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096) assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
@ -367,7 +374,7 @@ class T2VSynthMochiModel:
if batch_cfg: if batch_cfg:
z = z[:B] z = z[:B]
z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
self.dit.to("cpu") self.dit.to(self.offload_device)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std) samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
print("samples: ", samples.shape, samples.dtype, samples.device) print("samples: ", samples.shape, samples.dtype, samples.device)

View File

@ -57,7 +57,7 @@ class DownloadAndLoadMochiModel:
], ],
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", }, {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
), ),
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
{"default": "fp8_e4m3fn", } {"default": "fp8_e4m3fn", }
), ),
}, },
@ -75,7 +75,7 @@ class DownloadAndLoadMochiModel:
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
mm.soft_empty_cache() mm.soft_empty_cache()
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
# Transformer model # Transformer model
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi') model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
@ -107,10 +107,12 @@ class DownloadAndLoadMochiModel:
) )
model = T2VSynthMochiModel( model = T2VSynthMochiModel(
device_id=0, device=device,
offload_device=offload_device,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path, dit_checkpoint_path=model_path,
weight_dtype=dtype, weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
) )
with (init_empty_weights() if is_accelerate_available else nullcontext()): with (init_empty_weights() if is_accelerate_available else nullcontext()):
vae = Decoder( vae = Decoder(
@ -241,14 +243,13 @@ class MochiDecode:
"vae": ("MOCHIVAE",), "vae": ("MOCHIVAE",),
"samples": ("LATENT", ), "samples": ("LATENT", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
}, "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
"optional": { "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
"tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}), "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
"tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}), "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}), "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}), },
}
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
@ -256,7 +257,8 @@ class MochiDecode:
FUNCTION = "decode" FUNCTION = "decode"
CATEGORY = "MochiWrapper" CATEGORY = "MochiWrapper"
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True): def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
tile_overlap_factor_width, auto_tile_size, frame_batch_size):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
samples = samples["samples"] samples = samples["samples"]
@ -279,8 +281,6 @@ class MochiDecode:
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6 self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5 self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
#7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
self.num_latent_frames_batch_size = 6
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8 self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8 self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
@ -300,10 +300,10 @@ class MochiDecode:
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately. # Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles. # The tiles have an overlap to avoid seams between tiles.
comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
rows = [] rows = []
for i in tqdm(range(0, height, overlap_height), desc="Processing rows"): for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
row = [] row = []
@ -324,6 +324,7 @@ class MochiDecode:
time.append(tile) time.append(tile)
row.append(torch.cat(time, dim=2)) row.append(torch.cat(time, dim=2))
rows.append(row) rows.append(row)
comfy_pbar.update(1)
result_rows = [] result_rows = []
for i, row in enumerate(tqdm(rows, desc="Blending rows")): for i, row in enumerate(tqdm(rows, desc="Blending rows")):