spatial VAE decoder fixes

This commit is contained in:
kijai 2024-10-30 22:06:55 +02:00
parent f0f939b20b
commit d971a19410
2 changed files with 38 additions and 21 deletions

View File

@ -550,6 +550,7 @@ class Decoder(nn.Module):
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
dtype: torch.dtype = torch.float32,
**block_kwargs,
):
super().__init__()
@ -558,6 +559,7 @@ class Decoder(nn.Module):
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.output_nonlinearity = output_nonlinearity
self.dtype = dtype
assert nonlinearity == "silu"
assert causal
@ -718,18 +720,27 @@ def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tens
def nearest_multiple(x: int, multiple: int) -> int:
    """Snap ``x`` to a nearby integer multiple of ``multiple``.

    The quotient ``x / multiple`` is rounded with the built-in ``round``,
    so an exact half goes to the even quotient: ``nearest_multiple(5, 2)``
    is 4 and ``nearest_multiple(3, 2)`` is also 4.

    Raises ``ZeroDivisionError`` when ``multiple`` is 0.
    """
    quotient = x / multiple
    whole_steps = round(quotient)
    return whole_steps * multiple
from tqdm import tqdm
from comfy.utils import ProgressBar
def apply_tiled(
fn: Callable[[torch.Tensor], torch.Tensor],
x: torch.Tensor,
num_tiles_w: int,
num_tiles_h: int,
overlap: int = 0, # Number of pixel of overlap between adjacent tiles.
# Use a factor of 2 times the latent downsample factor.
overlap: int = 0, # Number of pixels of overlap between adjacent tiles.
min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
pbar: Optional[tqdm] = None,
comfy_pbar: Optional[ProgressBar] = None,
):
if pbar is None:
total_tiles = num_tiles_w * num_tiles_h
pbar = tqdm(total=total_tiles)
comfy_pbar = ProgressBar(total_tiles)
if num_tiles_w == 1 and num_tiles_h == 1:
return fn(x)
result = fn(x)
pbar.update(1)
comfy_pbar.update(1)
return result
assert (
num_tiles_w & (num_tiles_w - 1) == 0
@ -752,10 +763,10 @@ def apply_tiled(
assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even"
left = apply_tiled(
fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar
)
right = apply_tiled(
fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar
)
if left is None or right is None:
return None
@ -774,10 +785,10 @@ def apply_tiled(
assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even"
top = apply_tiled(
fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar
)
bottom = apply_tiled(
fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar
)
if top is None or bottom is None:
return None

View File

@ -240,6 +240,7 @@ class MochiVAELoader:
},
"optional": {
"torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
},
}
@ -248,12 +249,14 @@ class MochiVAELoader:
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, torch_compile_args=None):
def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
with (init_empty_weights() if is_accelerate_available else nullcontext()):
@ -271,25 +274,22 @@ class MochiVAELoader:
nonlinearity="silu",
output_nonlinearity="silu",
causal=True,
dtype=dtype,
)
vae_sd = load_torch_file(vae_path)
if is_accelerate_available:
for key in vae_sd:
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
for name, param in vae.named_parameters():
set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
else:
vae.load_state_dict(vae_sd, strict=True)
vae.to(torch.bfloat16).to("cpu")
vae.to(dtype).to(offload_device)
vae.eval()
del vae_sd
if torch_compile_args is not None:
vae.to(device)
# for i, block in enumerate(vae.blocks):
# if "CausalUpsampleBlock" in str(type(block)):
# print("Compiling block", block)
vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
return (vae,)
class MochiTextEncode:
@ -447,7 +447,7 @@ class MochiDecode:
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
samples = samples.to(torch.bfloat16).to(device)
samples = samples.to(vae.dtype).to(device)
B, C, T, H, W = samples.shape
@ -574,21 +574,27 @@ class MochiDecodeSpatialTiling:
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
samples = samples.to(torch.bfloat16).to(device)
samples = samples.to(vae.dtype).to(device)
B, C, T, H, W = samples.shape
vae.to(device)
decoded_list = []
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
if enable_vae_tiling:
from .mochi_preview.vae.model import apply_tiled
logging.warning("Decoding with tiling...")
pbar = ProgressBar(T // per_batch)
for i in range(0, T, per_batch):
if i >= T:
break
end_index = min(i + per_batch, T)
chunk = samples[:, :, i:end_index, :, :]
logging.info(f"Decoding {end_index - i} samples with tiling...")
chunk = samples[:, :, i:end_index, :, :]
frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
print(frames.shape)
logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples")
pbar.update(1)
# Blend the first and last frames of each pair
if len(decoded_list) > 0:
previous_frames = decoded_list[-1]