spatial VAE decoder fixes

parent f0f939b20b
commit d971a19410
@@ -550,6 +550,7 @@ class Decoder(nn.Module):
         nonlinearity: str = "silu",
         output_nonlinearity: str = "silu",
         causal: bool = True,
+        dtype: torch.dtype = torch.float32,
         **block_kwargs,
     ):
         super().__init__()
@@ -558,6 +559,7 @@ class Decoder(nn.Module):
         self.channel_multipliers = channel_multipliers
         self.num_res_blocks = num_res_blocks
         self.output_nonlinearity = output_nonlinearity
+        self.dtype = dtype
         assert nonlinearity == "silu"
         assert causal

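These two hunks thread a configurable dtype through the Decoder: the constructor takes it as an argument and stores it on the module, so the loader can build the VAE in fp16/fp32/bf16 instead of an implicit float32, and callers can query vae.dtype later. A minimal sketch of the pattern on a toy module (TinyDecoder is illustrative, not the real class):

    import torch
    import torch.nn as nn

    class TinyDecoder(nn.Module):
        def __init__(self, dtype: torch.dtype = torch.float32):
            super().__init__()
            self.dtype = dtype                        # remembered so callers can do x.to(vae.dtype)
            self.conv = nn.Conv2d(4, 3, 1).to(dtype)  # weights created in the requested dtype

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.conv(x.to(self.dtype))        # cast inputs to match the weights

    dec = TinyDecoder(dtype=torch.bfloat16)
    out = dec(torch.randn(1, 4, 8, 8))                # fp32 input is cast to bf16 internally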
@@ -718,18 +720,27 @@ def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
 def nearest_multiple(x: int, multiple: int) -> int:
     return round(x / multiple) * multiple


+from tqdm import tqdm
+from comfy.utils import ProgressBar
 def apply_tiled(
     fn: Callable[[torch.Tensor], torch.Tensor],
     x: torch.Tensor,
     num_tiles_w: int,
     num_tiles_h: int,
-    overlap: int = 0,  # Number of pixel of overlap between adjacent tiles.
+    overlap: int = 0,  # Number of pixels of overlap between adjacent tiles.
     # Use a factor of 2 times the latent downsample factor.
     min_block_size: int = 1,  # Minimum number of pixels in each dimension when subdividing.
+    pbar: Optional[tqdm] = None,
+    comfy_pbar: Optional[ProgressBar] = None,
 ):
+    if pbar is None:
+        total_tiles = num_tiles_w * num_tiles_h
+        pbar = tqdm(total=total_tiles)
+        comfy_pbar = ProgressBar(total_tiles)
     if num_tiles_w == 1 and num_tiles_h == 1:
-        return fn(x)
+        result = fn(x)
+        pbar.update(1)
+        comfy_pbar.update(1)
+        return result

     assert (
         num_tiles_w & (num_tiles_w - 1) == 0
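The shared pbar/comfy_pbar are created once, at the outermost call, with total = num_tiles_w * num_tiles_h, then handed down through every recursive call; fn runs exactly once per leaf tile, so the counters line up. That accounting works because both tile counts are powers of two (the num_tiles_w & (num_tiles_w - 1) == 0 bit trick asserts this), so repeated halving bottoms out cleanly at 1x1. A quick standalone check of that invariant (count_leaf_calls is illustrative; which axis the real code splits first is not shown in this hunk):

    def count_leaf_calls(num_tiles_w: int, num_tiles_h: int) -> int:
        # Mirrors the recursion shape: each level halves one power-of-two
        # tile count until the 1x1 base case, where fn(x) runs once.
        if num_tiles_w == 1 and num_tiles_h == 1:
            return 1
        if num_tiles_w > 1:
            return 2 * count_leaf_calls(num_tiles_w // 2, num_tiles_h)
        return 2 * count_leaf_calls(num_tiles_w, num_tiles_h // 2)

    assert count_leaf_calls(4, 2) == 8  # matches total_tiles = num_tiles_w * num_tiles_h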
@@ -752,10 +763,10 @@ def apply_tiled(

     assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even"
     left = apply_tiled(
-        fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
+        fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar
     )
     right = apply_tiled(
-        fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
+        fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar
     )
     if left is None or right is None:
         return None
@@ -774,10 +785,10 @@ def apply_tiled(

     assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even"
     top = apply_tiled(
-        fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
+        fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar
     )
     bottom = apply_tiled(
-        fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
+        fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar
     )
     if top is None or bottom is None:
         return None

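For context on what happens after these recursive calls return: each half is decoded with `overlap` extra pixels along the shared edge, and the halves are crossfaded back together (blend_vertical, named in the hunk header above, handles the top/bottom seam). The exact weighting is not visible in this diff, so the following is only a sketch of a linear crossfade under that assumption:

    import torch

    def blend_vertical_sketch(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
        # a: top half, b: bottom half, both (..., H, W); the last `overlap` rows
        # of `a` cover the same pixels as the first `overlap` rows of `b`.
        w = torch.linspace(0, 1, overlap, device=a.device).view(-1, 1)  # ramps 0 -> 1 down the seam
        seam = a[..., -overlap:, :] * (1 - w) + b[..., :overlap, :] * w
        return torch.cat([a[..., :-overlap, :], seam, b[..., overlap:, :]], dim=-2)

    top, bottom = torch.zeros(1, 3, 8, 8), torch.ones(1, 3, 8, 8)
    print(blend_vertical_sketch(top, bottom, 4).shape)  # torch.Size([1, 3, 12, 8])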
nodes.py (32 changed lines)
@@ -240,6 +240,7 @@ class MochiVAELoader:
             },
             "optional": {
                 "torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
             },
         }

@@ -248,12 +249,14 @@ class MochiVAELoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"

-    def loadmodel(self, model_name, torch_compile_args=None):
+    def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):

         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()

+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+
         vae_path = folder_paths.get_full_path_or_raise("vae", model_name)

         with (init_empty_weights() if is_accelerate_available else nullcontext()):
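ComfyUI feeds "optional" inputs to the node function as keyword arguments, so the new precision dropdown arrives in loadmodel as a plain string defaulting to "bf16", and the dict lookup turns it into a torch dtype. A slightly more defensive version of the same mapping, for illustration (resolve_dtype is a hypothetical helper, not in the repo):

    import torch

    DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

    def resolve_dtype(precision: str) -> torch.dtype:
        # Same lookup as in loadmodel, but with a readable error instead of a bare KeyError.
        try:
            return DTYPE_MAP[precision]
        except KeyError:
            raise ValueError(f"unknown precision {precision!r}; expected one of {list(DTYPE_MAP)}")

    print(resolve_dtype("bf16"))  # torch.bfloat16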
@@ -271,25 +274,22 @@ class MochiVAELoader:
                 nonlinearity="silu",
                 output_nonlinearity="silu",
                 causal=True,
+                dtype=dtype,
             )
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
-            for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
+            for name, param in vae.named_parameters():
+                set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-            vae.to(torch.bfloat16).to("cpu")
+            vae.to(dtype).to(offload_device)
         vae.eval()
         del vae_sd

         if torch_compile_args is not None:
             vae.to(device)
-            # for i, block in enumerate(vae.blocks):
-            #     if "CausalUpsampleBlock" in str(type(block)):
-            #         print("Compiling block", block)
             vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])


         return (vae,)

 class MochiTextEncode:
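The accelerate path is the substantive fix here: the model is instantiated under init_empty_weights, so its parameters live on the meta device and own no memory, and every tensor must be materialized from the checkpoint afterwards. The old loop walked the state-dict keys and hard-coded float32; the new one walks vae.named_parameters() and materializes each parameter in the user-selected dtype. The same pattern on a toy module, assuming accelerate is installed:

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():  # parameters are created on the meta device, no RAM used
        model = nn.Linear(4, 4)

    sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
    for name, _ in model.named_parameters():
        # materialize each parameter from the checkpoint in the requested dtype
        set_module_tensor_to_device(model, name, device="cpu", dtype=torch.bfloat16, value=sd[name])

    print(model.weight.dtype)  # torch.bfloat16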
@@ -447,7 +447,7 @@ class MochiDecode:
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = samples.to(torch.bfloat16).to(device)
+        samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape

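Both decode nodes previously cast the latents to bfloat16 unconditionally, which breaks as soon as the VAE is loaded in fp16 or fp32: convolution inputs and weights must share a dtype. Casting to vae.dtype (the attribute added to Decoder above) keeps the two in sync regardless of the chosen precision. A toy demonstration of the failure mode (the Conv3d stands in for the VAE; channel counts are arbitrary):

    import torch
    import torch.nn as nn

    vae_like = nn.Conv3d(12, 3, 1).to(torch.float16)         # stand-in for a VAE loaded in fp16
    latents = torch.randn(1, 12, 2, 4, 4, dtype=torch.bfloat16)

    try:
        vae_like(latents)                                    # bf16 input vs fp16 weights
    except RuntimeError as e:
        print("dtype mismatch:", e)

    out = vae_like(latents.to(vae_like.weight.dtype))        # the fix: follow the model's dtype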
@@ -574,21 +574,27 @@ class MochiDecodeSpatialTiling:
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = samples.to(torch.bfloat16).to(device)
+        samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape


         vae.to(device)
         decoded_list = []
         with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
             if enable_vae_tiling:
                 from .mochi_preview.vae.model import apply_tiled
                 logging.warning("Decoding with tiling...")

+                pbar = ProgressBar(T // per_batch)
+                for i in range(0, T, per_batch):
+                    if i >= T:
+                        break
+                    end_index = min(i + per_batch, T)
-                    chunk = samples[:, :, i:end_index, :, :]
                     logging.info(f"Decoding {end_index - i} samples with tiling...")
+                    chunk = samples[:, :, i:end_index, :, :]
                     frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
-                    print(frames.shape)
+                    logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples")
+                    pbar.update(1)
                     # Blend the first and last frames of each pair
                     if len(decoded_list) > 0:
                         previous_frames = decoded_list[-1]

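This hunk is the core of the "spatial VAE decoder fixes" in the commit title: the tiled path previously referenced the loop variables i and end_index without a surrounding loop, so it could only ever decode one batch. The new code walks the latent time axis in per_batch slices, decodes each chunk through apply_tiled, and blends chunk boundaries before concatenation (the blend body is truncated in this diff). A self-contained skeleton of that loop; the seam-averaging step is an assumption, since the real blend code is not shown:

    import torch

    def decode_in_chunks(decode_fn, samples: torch.Tensor, per_batch: int) -> torch.Tensor:
        # samples: (B, C, T, H, W) latents; decode_fn maps a latent chunk to frames.
        B, C, T, H, W = samples.shape
        decoded_list = []
        for i in range(0, T, per_batch):
            end_index = min(i + per_batch, T)
            frames = decode_fn(samples[:, :, i:end_index])   # decode one temporal slice
            if decoded_list:
                previous = decoded_list[-1]
                # assumed blend: average the seam frames of adjacent chunks
                seam = (previous[:, :, -1:] + frames[:, :, :1]) / 2
                decoded_list[-1] = previous[:, :, :-1]
                frames = torch.cat([seam, frames[:, :, 1:]], dim=2)
            decoded_list.append(frames)
        return torch.cat(decoded_list, dim=2)

    # Identity decode just to exercise the chunking/blending logic:
    out = decode_in_chunks(lambda x: x, torch.randn(1, 3, 10, 4, 4), per_batch=4)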