This commit is contained in:
kijai 2024-10-23 17:04:50 +03:00
parent db87f8e608
commit fb880273a0
3 changed files with 68 additions and 15 deletions

45
fp8_optimization.py Normal file
View File

@ -0,0 +1,45 @@
#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
import torch
import torch.nn as nn
def fp8_linear_forward(cls, original_dtype, input):
weight_dtype = cls.weight.dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if len(input.shape) == 3:
if weight_dtype == torch.float8_e4m3fn:
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
else:
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
w = cls.weight.t()
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
scale_input = scale_weight
bias = cls.bias.to(original_dtype) if cls.bias is not None else None
out_dtype = original_dtype
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
else:
cls.to(original_dtype)
out = cls.original_forward(input.to(original_dtype))
cls.to(original_dtype)
return out
else:
return cls.original_forward(input)
def convert_fp8_linear(module, original_dtype):
setattr(module, "fp8_matmul_enabled", True)
for name, module in module.named_modules():
if isinstance(module, nn.Linear):
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -113,14 +113,17 @@ class T2VSynthMochiModel:
def __init__(
self,
*,
device_id: int,
device: torch.device,
offload_device: torch.device,
vae_stats_path: str,
dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False,
):
super().__init__()
t = Timer()
self.device = torch.device(device_id)
self.device = device
self.offload_device = offload_device
with t("construct_dit"):
from .dit.joint_model.asymm_models_joint import (
@ -162,6 +165,10 @@ class T2VSynthMochiModel:
param.data = param.data.to(weight_dtype)
else:
param.data = param.data.to(torch.bfloat16)
if fp8_fastmode:
from ..fp8_optimization import convert_fp8_linear
convert_fp8_linear(model, torch.bfloat16)
self.dit = model
self.dit.eval()
@ -211,7 +218,7 @@ class T2VSynthMochiModel:
caption_input_ids_t5, caption_attention_mask_t5
).last_hidden_state.detach().to(torch.float32)
)
self.t5_enc.to("cpu")
self.t5_enc.to(self.offload_device)
# Sometimes returns a tensor, othertimes a tuple, not sure why
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
@ -367,7 +374,7 @@ class T2VSynthMochiModel:
if batch_cfg:
z = z[:B]
z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
self.dit.to("cpu")
self.dit.to(self.offload_device)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
print("samples: ", samples.shape, samples.dtype, samples.device)

View File

@ -57,7 +57,7 @@ class DownloadAndLoadMochiModel:
],
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
),
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
{"default": "fp8_e4m3fn", }
),
},
@ -75,7 +75,7 @@ class DownloadAndLoadMochiModel:
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
# Transformer model
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
@ -107,10 +107,12 @@ class DownloadAndLoadMochiModel:
)
model = T2VSynthMochiModel(
device_id=0,
device=device,
offload_device=offload_device,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
)
with (init_empty_weights() if is_accelerate_available else nullcontext()):
vae = Decoder(
@ -241,14 +243,13 @@ class MochiDecode:
"vae": ("MOCHIVAE",),
"samples": ("LATENT", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
},
"optional": {
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
"frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
"tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
"tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
}
},
}
RETURN_TYPES = ("IMAGE",)
@ -256,7 +257,8 @@ class MochiDecode:
FUNCTION = "decode"
CATEGORY = "MochiWrapper"
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
tile_overlap_factor_width, auto_tile_size, frame_batch_size):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
samples = samples["samples"]
@ -279,8 +281,6 @@ class MochiDecode:
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
#7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
self.num_latent_frames_batch_size = 6
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
@ -300,10 +300,10 @@ class MochiDecode:
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
rows = []
for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
row = []
@ -324,6 +324,7 @@ class MochiDecode:
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
comfy_pbar.update(1)
result_rows = []
for i, row in enumerate(tqdm(rows, desc="Blending rows")):