update

commit fb880273a0 (parent db87f8e608)

fp8_optimization.py — new file (45 lines)
@@ -0,0 +1,45 @@
+#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+
+import torch
+import torch.nn as nn
+
+def fp8_linear_forward(cls, original_dtype, input):
+    weight_dtype = cls.weight.dtype
+    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        if len(input.shape) == 3:
+            if weight_dtype == torch.float8_e4m3fn:
+                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
+            else:
+                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
+            w = cls.weight.t()
+
+            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = scale_weight
+
+            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
+            out_dtype = original_dtype
+
+            if bias is not None:
+                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            else:
+                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
+
+            if isinstance(o, tuple):
+                o = o[0]
+
+            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+        else:
+            cls.to(original_dtype)
+            out = cls.original_forward(input.to(original_dtype))
+            cls.to(original_dtype)
+            return out
+    else:
+        return cls.original_forward(input)
+
+def convert_fp8_linear(module, original_dtype):
+    setattr(module, "fp8_matmul_enabled", True)
+    for name, module in module.named_modules():
+        if isinstance(module, nn.Linear):
+            original_forward = module.forward
+            setattr(module, "original_forward", original_forward)
+            setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
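Usage note: a minimal sketch of how convert_fp8_linear is meant to be driven. The toy module below is a hypothetical stand-in for the DiT, and the flat import path assumes the script sits next to fp8_optimization.py; fp8 matmul via torch._scaled_mm needs a CUDA GPU with fp8 support (compute capability 8.9+) and matrix dimensions that are multiples of 16.

import torch
import torch.nn as nn
from fp8_optimization import convert_fp8_linear

# Toy stand-in for the DiT: weights stored in fp8, activations kept in bf16.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).cuda()
for param in model.parameters():
    param.data = param.data.to(torch.float8_e4m3fn)
convert_fp8_linear(model, torch.bfloat16)

# Rank-3 input (batch, tokens, channels) takes the _scaled_mm fast path;
# any other rank falls back to original_forward with the module cast to original_dtype.
x = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
print(y.shape, y.dtype)  # torch.Size([2, 16, 64]) torch.bfloat16

Note that the patched forward uses unit scales (scale_a = scale_b = 1.0), so this trades precision for speed rather than doing calibrated fp8 quantization; and older PyTorch builds return an (out, amax) tuple from torch._scaled_mm, which is what the isinstance(o, tuple) check handles.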
@@ -113,14 +113,17 @@ class T2VSynthMochiModel:
     def __init__(
         self,
         *,
-        device_id: int,
+        device: torch.device,
+        offload_device: torch.device,
         vae_stats_path: str,
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
+        fp8_fastmode: bool = False,
     ):
         super().__init__()
         t = Timer()
-        self.device = torch.device(device_id)
+        self.device = device
+        self.offload_device = offload_device

         with t("construct_dit"):
             from .dit.joint_model.asymm_models_joint import (
@@ -162,6 +165,10 @@ class T2VSynthMochiModel:
                 param.data = param.data.to(weight_dtype)
             else:
                 param.data = param.data.to(torch.bfloat16)

+        if fp8_fastmode:
+            from ..fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(model, torch.bfloat16)
+
         self.dit = model
         self.dit.eval()
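torch._scaled_mm is a private PyTorch API and fp8 tensor-core matmul only exists on recent GPUs, so a defensive guard around fastmode could look like the following (a hypothetical helper, not part of the commit):

import torch

def fp8_fastmode_supported() -> bool:
    # fp8 matmul needs compute capability 8.9+ (Ada) or 9.0 (Hopper)
    if not torch.cuda.is_available() or not hasattr(torch, "_scaled_mm"):
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

# usage sketch: if fp8_fastmode and fp8_fastmode_supported():
#                   convert_fp8_linear(model, torch.bfloat16)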
@@ -211,7 +218,7 @@ class T2VSynthMochiModel:
                     caption_input_ids_t5, caption_attention_mask_t5
                 ).last_hidden_state.detach().to(torch.float32)
             )
-        self.t5_enc.to("cpu")
+        self.t5_enc.to(self.offload_device)
         # Sometimes returns a tensor, other times a tuple; not sure why
         # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
         assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
@@ -367,7 +374,7 @@ class T2VSynthMochiModel:
             if batch_cfg:
                 z = z[:B]
             z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
-        self.dit.to("cpu")
+        self.dit.to(self.offload_device)

         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         print("samples: ", samples.shape, samples.dtype, samples.device)
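Both hunks replace hard-coded "cpu" targets with the offload_device passed in from the loader, following ComfyUI's usual keep-one-model-resident pattern. A minimal sketch of that pattern, assuming mm is comfy.model_management as imported in nodes.py:

import torch
import torch.nn as nn
import comfy.model_management as mm

device = mm.get_torch_device()             # compute device, e.g. cuda:0
offload_device = mm.unet_offload_device()  # usually cpu

module = nn.Linear(8, 8)                   # stand-in for the T5 encoder / DiT
module.to(device)                          # bring in for its stage
out = module(torch.randn(1, 8, device=device))
module.to(offload_device)                  # evict to free VRAM for the next stage
mm.soft_empty_cache()                      # let the allocator release cached blocks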
nodes.py (23 lines changed)
@@ -57,7 +57,7 @@ class DownloadAndLoadMochiModel:
                 ],
                 {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
             ),
-            "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
+            "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
                 {"default": "fp8_e4m3fn", }
             ),
         },
@@ -75,7 +75,7 @@ class DownloadAndLoadMochiModel:
        offload_device = mm.unet_offload_device()
        mm.soft_empty_cache()

-        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        # Transformer model
        model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
@@ -107,10 +107,12 @@ class DownloadAndLoadMochiModel:
            )

        model = T2VSynthMochiModel(
-            device_id=0,
+            device=device,
+            offload_device=offload_device,
            vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
            dit_checkpoint_path=model_path,
            weight_dtype=dtype,
+            fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
        )
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            vae = Decoder(
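Both fp8 presets resolve to the same storage dtype; the _fast suffix only flips the matmul path. Spelled out (and note the ternary could be written as a bare comparison):

import torch

precision = "fp8_e4m3fn_fast"
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn,
         "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
fp8_fastmode = precision == "fp8_e4m3fn_fast"  # equivalent to: True if ... else False
# -> dtype is torch.float8_e4m3fn, fp8_fastmode is True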
@@ -241,14 +243,13 @@ class MochiDecode:
                "vae": ("MOCHIVAE",),
                "samples": ("LATENT", ),
                "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
-            },
-            "optional": {
+                "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
+                "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
                "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
                "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
                "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
-                "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
-            }
+            },
        }

    RETURN_TYPES = ("IMAGE",)
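The decode inputs are consolidated here: auto_tile_size and the new frame_batch_size join the required group and the separate "optional" group is dropped, so every widget value always reaches decode(), whose signature changes in the next hunk.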
@@ -256,7 +257,8 @@ class MochiDecode:
    FUNCTION = "decode"
    CATEGORY = "MochiWrapper"

-    def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
+    def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        samples = samples["samples"]
@@ -279,8 +281,6 @@ class MochiDecode:

        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
-        #7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
-        self.num_latent_frames_batch_size = 6

        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
@@ -300,10 +300,10 @@ class MochiDecode:
        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_sample_min_height - blend_extent_height
        row_limit_width = self.tile_sample_min_width - blend_extent_width
-        frame_batch_size = self.num_latent_frames_batch_size

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
+        comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
        rows = []
        for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
            row = []
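For intuition, the tile arithmetic works out as follows (a worked sketch assuming auto_tile_size and a hypothetical 480x848 output, i.e. latent spatial size 60x106; only formulas visible in this diff are used):

tile_sample_min_height = 60 // 2 * 8      # 240 px (matches the widget default)
tile_sample_min_width = 106 // 2 * 8      # 424 px (matches the widget default)
blend_extent_height = int(240 * (1 / 6))  # 40 px blended between vertical neighbors
row_limit_height = 240 - 40               # 200 px of unique content kept per tile row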
@@ -324,6 +324,7 @@ class MochiDecode:
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
+            comfy_pbar.update(1)

        result_rows = []
        for i, row in enumerate(tqdm(rows, desc="Blending rows")):
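ProgressBar here is ComfyUI's frontend progress hook; it takes a total step count and is advanced manually, one step per row of tiles, alongside the console tqdm bar. The pattern, assuming the usual import (the height/overlap values below are hypothetical):

from comfy.utils import ProgressBar

height, overlap_height = 240, 200                          # illustration only
pbar = ProgressBar(len(range(0, height, overlap_height)))  # one step per tile row
for i in range(0, height, overlap_height):
    # ... decode one row of tiles here ...
    pbar.update(1)                                         # pushes progress to the ComfyUI UI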