update
This commit is contained in:
parent
db87f8e608
commit
fb880273a0
45
fp8_optimization.py
Normal file
45
fp8_optimization.py
Normal file
@ -0,0 +1,45 @@
|
||||
#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def fp8_linear_forward(cls, original_dtype, input):
|
||||
weight_dtype = cls.weight.dtype
|
||||
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
if len(input.shape) == 3:
|
||||
if weight_dtype == torch.float8_e4m3fn:
|
||||
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
|
||||
else:
|
||||
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
|
||||
w = cls.weight.t()
|
||||
|
||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
scale_input = scale_weight
|
||||
|
||||
bias = cls.bias.to(original_dtype) if cls.bias is not None else None
|
||||
out_dtype = original_dtype
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
|
||||
else:
|
||||
cls.to(original_dtype)
|
||||
out = cls.original_forward(input.to(original_dtype))
|
||||
cls.to(original_dtype)
|
||||
return out
|
||||
else:
|
||||
return cls.original_forward(input)
|
||||
|
||||
def convert_fp8_linear(module, original_dtype):
|
||||
setattr(module, "fp8_matmul_enabled", True)
|
||||
for name, module in module.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
original_forward = module.forward
|
||||
setattr(module, "original_forward", original_forward)
|
||||
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
||||
@ -113,14 +113,17 @@ class T2VSynthMochiModel:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_id: int,
|
||||
device: torch.device,
|
||||
offload_device: torch.device,
|
||||
vae_stats_path: str,
|
||||
dit_checkpoint_path: str,
|
||||
weight_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||
fp8_fastmode: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
t = Timer()
|
||||
self.device = torch.device(device_id)
|
||||
self.device = device
|
||||
self.offload_device = offload_device
|
||||
|
||||
with t("construct_dit"):
|
||||
from .dit.joint_model.asymm_models_joint import (
|
||||
@ -162,6 +165,10 @@ class T2VSynthMochiModel:
|
||||
param.data = param.data.to(weight_dtype)
|
||||
else:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
|
||||
if fp8_fastmode:
|
||||
from ..fp8_optimization import convert_fp8_linear
|
||||
convert_fp8_linear(model, torch.bfloat16)
|
||||
|
||||
self.dit = model
|
||||
self.dit.eval()
|
||||
@ -211,7 +218,7 @@ class T2VSynthMochiModel:
|
||||
caption_input_ids_t5, caption_attention_mask_t5
|
||||
).last_hidden_state.detach().to(torch.float32)
|
||||
)
|
||||
self.t5_enc.to("cpu")
|
||||
self.t5_enc.to(self.offload_device)
|
||||
# Sometimes returns a tensor, othertimes a tuple, not sure why
|
||||
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
|
||||
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
|
||||
@ -367,7 +374,7 @@ class T2VSynthMochiModel:
|
||||
if batch_cfg:
|
||||
z = z[:B]
|
||||
z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
|
||||
self.dit.to("cpu")
|
||||
self.dit.to(self.offload_device)
|
||||
|
||||
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
||||
print("samples: ", samples.shape, samples.dtype, samples.device)
|
||||
|
||||
23
nodes.py
23
nodes.py
@ -57,7 +57,7 @@ class DownloadAndLoadMochiModel:
|
||||
],
|
||||
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
|
||||
),
|
||||
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
||||
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
|
||||
{"default": "fp8_e4m3fn", }
|
||||
),
|
||||
},
|
||||
@ -75,7 +75,7 @@ class DownloadAndLoadMochiModel:
|
||||
offload_device = mm.unet_offload_device()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
|
||||
# Transformer model
|
||||
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
|
||||
@ -107,10 +107,12 @@ class DownloadAndLoadMochiModel:
|
||||
)
|
||||
|
||||
model = T2VSynthMochiModel(
|
||||
device_id=0,
|
||||
device=device,
|
||||
offload_device=offload_device,
|
||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||
dit_checkpoint_path=model_path,
|
||||
weight_dtype=dtype,
|
||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||
)
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
vae = Decoder(
|
||||
@ -241,14 +243,13 @@ class MochiDecode:
|
||||
"vae": ("MOCHIVAE",),
|
||||
"samples": ("LATENT", ),
|
||||
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
|
||||
},
|
||||
"optional": {
|
||||
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
|
||||
"frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
|
||||
"tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
|
||||
"tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
|
||||
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
@ -256,7 +257,8 @@ class MochiDecode:
|
||||
FUNCTION = "decode"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
|
||||
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
|
||||
tile_overlap_factor_width, auto_tile_size, frame_batch_size):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
samples = samples["samples"]
|
||||
@ -279,8 +281,6 @@ class MochiDecode:
|
||||
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
|
||||
#7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
|
||||
self.num_latent_frames_batch_size = 6
|
||||
|
||||
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
|
||||
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
|
||||
@ -300,10 +300,10 @@ class MochiDecode:
|
||||
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
||||
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
||||
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
|
||||
rows = []
|
||||
for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
|
||||
row = []
|
||||
@ -324,6 +324,7 @@ class MochiDecode:
|
||||
time.append(tile)
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
comfy_pbar.update(1)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(tqdm(rows, desc="Blending rows")):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user