spatial VAE decoder fixes
Commit d971a19410 (parent f0f939b20b)

mochi_preview/vae/model.py
@@ -550,6 +550,7 @@ class Decoder(nn.Module):
         nonlinearity: str = "silu",
         output_nonlinearity: str = "silu",
         causal: bool = True,
+        dtype: torch.dtype = torch.float32,
         **block_kwargs,
     ):
         super().__init__()
@@ -558,6 +559,7 @@ class Decoder(nn.Module):
         self.channel_multipliers = channel_multipliers
         self.num_res_blocks = num_res_blocks
         self.output_nonlinearity = output_nonlinearity
+        self.dtype = dtype
         assert nonlinearity == "silu"
         assert causal

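Note: storing the target dtype on the Decoder lets callers cast latents to match the module instead of hard-coding torch.bfloat16 (see the samples.to(vae.dtype) changes in nodes.py below). A minimal sketch of the pattern, with illustrative names not taken from the repo:

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    # Illustrative stand-in for the Mochi Decoder's dtype plumbing.
    def __init__(self, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dtype = dtype  # remembered so callers can match their inputs to it
        self.proj = nn.Conv2d(4, 3, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

decoder = TinyDecoder(dtype=torch.bfloat16).to(torch.bfloat16)
latents = torch.randn(1, 4, 8, 8)
frames = decoder(latents.to(decoder.dtype))  # cast to match, no hard-coded bf16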
@@ -718,18 +720,27 @@ def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:

 def nearest_multiple(x: int, multiple: int) -> int:
     return round(x / multiple) * multiple

+from tqdm import tqdm
+from comfy.utils import ProgressBar

 def apply_tiled(
     fn: Callable[[torch.Tensor], torch.Tensor],
     x: torch.Tensor,
     num_tiles_w: int,
     num_tiles_h: int,
-    overlap: int = 0,  # Number of pixel of overlap between adjacent tiles.
-    # Use a factor of 2 times the latent downsample factor.
+    overlap: int = 0,  # Number of pixels of overlap between adjacent tiles.
     min_block_size: int = 1,  # Minimum number of pixels in each dimension when subdividing.
+    pbar: Optional[tqdm] = None,
+    comfy_pbar: Optional[ProgressBar] = None,
 ):
+    if pbar is None:
+        total_tiles = num_tiles_w * num_tiles_h
+        pbar = tqdm(total=total_tiles)
+        comfy_pbar = ProgressBar(total_tiles)
     if num_tiles_w == 1 and num_tiles_h == 1:
-        return fn(x)
+        result = fn(x)
+        pbar.update(1)
+        comfy_pbar.update(1)
+        return result

     assert (
         num_tiles_w & (num_tiles_w - 1) == 0
@@ -752,10 +763,10 @@ def apply_tiled(

         assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even"
         left = apply_tiled(
-            fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
+            fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar
         )
         right = apply_tiled(
-            fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
+            fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar
         )
         if left is None or right is None:
             return None
@@ -774,10 +785,10 @@ def apply_tiled(

         assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even"
         top = apply_tiled(
-            fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
+            fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar
         )
         bottom = apply_tiled(
-            fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
+            fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar
         )
         if top is None or bottom is None:
             return None
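Note: the fix threads one shared pair of progress bars through the recursion. The top-level call creates them, each leaf tile updates them, and every recursive call passes them along, so the bar total stays num_tiles_w * num_tiles_h. A self-contained sketch of the same pattern (simplified: no overlap blending, tqdm only):

from typing import Callable, Optional

import torch
from tqdm import tqdm

def tiled(
    fn: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    num_tiles_w: int,
    num_tiles_h: int,
    pbar: Optional[tqdm] = None,
) -> torch.Tensor:
    # The top-level call owns the bar; recursive calls reuse the same one.
    if pbar is None:
        pbar = tqdm(total=num_tiles_w * num_tiles_h)
    if num_tiles_w == 1 and num_tiles_h == 1:
        result = fn(x)  # leaf: exactly one tile processed
        pbar.update(1)
        return result
    if num_tiles_w > 1:  # halve along width first
        mid = x.shape[-1] // 2
        left = tiled(fn, x[..., :mid], num_tiles_w // 2, num_tiles_h, pbar)
        right = tiled(fn, x[..., mid:], num_tiles_w // 2, num_tiles_h, pbar)
        return torch.cat([left, right], dim=-1)
    mid = x.shape[-2] // 2  # then halve along height
    top = tiled(fn, x[..., :mid, :], num_tiles_w, num_tiles_h // 2, pbar)
    bottom = tiled(fn, x[..., mid:, :], num_tiles_w, num_tiles_h // 2, pbar)
    return torch.cat([top, bottom], dim=-2)

# Usage: 4x2 tiles over a dummy image with an identity "decoder"; bar total is 8.
out = tiled(lambda t: t, torch.randn(1, 3, 64, 128), num_tiles_w=4, num_tiles_h=2)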

nodes.py (32 changed lines)
@@ -240,6 +240,7 @@ class MochiVAELoader:
             },
             "optional": {
                 "torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
             },
         }

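Note: in ComfyUI's node convention, each key under "optional" in INPUT_TYPES is delivered as a keyword argument to the method named by FUNCTION, which is why loadmodel below gains precision="bf16". A minimal sketch of that convention (class name and defaults are illustrative):

class ExampleVAELoader:
    # Keys under "optional" arrive as keyword arguments of the FUNCTION method.
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": ("STRING", {"default": "vae.safetensors"}),
            },
            "optional": {
                "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
            },
        }

    RETURN_TYPES = ("VAE",)
    FUNCTION = "loadmodel"
    CATEGORY = "Example"

    def loadmodel(self, model_name, precision="bf16"):
        # precision comes from the optional input; old workflows omit it,
        # so the Python default keeps them loading unchanged.
        return (f"{model_name} @ {precision}",)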
@@ -248,12 +249,14 @@ class MochiVAELoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"

-    def loadmodel(self, model_name, torch_compile_args=None):
+    def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):

         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()

+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+
         vae_path = folder_paths.get_full_path_or_raise("vae", model_name)

         with (init_empty_weights() if is_accelerate_available else nullcontext()):
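Note: the one-line dict lookup above is the entire precision plumbing. A hypothetical resolve_dtype helper (not in the repo) shows the same mapping with a clearer error for unknown strings:

import torch

DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

def resolve_dtype(precision: str) -> torch.dtype:
    # Same lookup as the node, but with a readable error instead of a bare KeyError.
    try:
        return DTYPE_MAP[precision]
    except KeyError:
        raise ValueError(
            f"unsupported precision {precision!r}; expected one of {sorted(DTYPE_MAP)}"
        ) from None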
@@ -271,25 +274,22 @@ class MochiVAELoader:
             nonlinearity="silu",
             output_nonlinearity="silu",
             causal=True,
+            dtype=dtype,
         )
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
-            for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key])
+            for name, param in vae.named_parameters():
+                set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-            vae.to(torch.bfloat16).to("cpu")
+            vae.to(dtype).to(offload_device)
         vae.eval()
         del vae_sd

         if torch_compile_args is not None:
             vae.to(device)
-            # for i, block in enumerate(vae.blocks):
-            #     if "CausalUpsampleBlock" in str(type(block)):
-            #         print("Compiling block", block)
             vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
-

         return (vae,)

 class MochiTextEncode:
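Note: the corrected loop iterates the module's own named_parameters() rather than raw state-dict keys, materializing meta-initialized weights at the requested dtype directly on the offload device. A sketch of this accelerate pattern, assuming state-dict keys match the parameter names:

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module on the meta device: no weight memory is allocated yet.
with init_empty_weights():
    model = nn.Linear(16, 16)

state_dict = {"weight": torch.randn(16, 16), "bias": torch.randn(16)}

# Materialize each parameter at the target dtype on the target device.
# Iterating named_parameters() guarantees every module tensor gets a value,
# which a loop over arbitrary state-dict keys does not.
for name, _ in model.named_parameters():
    set_module_tensor_to_device(
        model, name, device="cpu", dtype=torch.bfloat16, value=state_dict[name]
    )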
@@ -447,7 +447,7 @@ class MochiDecode:
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = samples.to(torch.bfloat16).to(device)
+        samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape

@@ -574,21 +574,27 @@ class MochiDecodeSpatialTiling:
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = samples.to(torch.bfloat16).to(device)
+        samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape


         vae.to(device)
         decoded_list = []
         with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
             if enable_vae_tiling:
                 from .mochi_preview.vae.model import apply_tiled
-                logging.warning("Decoding with tiling...")
+                pbar = ProgressBar(T // per_batch)
                 for i in range(0, T, per_batch):
+                    if i >= T:
+                        break
                     end_index = min(i + per_batch, T)
-                    chunk = samples[:, :, i:end_index, :, :]
+                    logging.info(f"Decoding {end_index - i} samples with tiling...")
+                    chunk = samples[:, :, i:end_index, :, :]
                     frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
-                    print(frames.shape)
+                    logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples")
+                    pbar.update(1)
                     # Blend the first and last frames of each pair
                     if len(decoded_list) > 0:
                         previous_frames = decoded_list[-1]
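Note: the tiled decode now reports progress per temporal chunk: T latent frames are decoded per_batch at a time, and each chunk advances the bar by one. A self-contained sketch of the chunking loop (tqdm stands in for ComfyUI's ProgressBar; the total uses ceil so a partial final chunk still counts as a step):

import math
from typing import Callable, List

import torch
from tqdm import tqdm

def decode_in_chunks(
    decode_fn: Callable[[torch.Tensor], torch.Tensor],
    samples: torch.Tensor,
    per_batch: int,
) -> List[torch.Tensor]:
    # samples is laid out (B, C, T, H, W); decode per_batch frames at a time.
    T = samples.shape[2]
    pbar = tqdm(total=math.ceil(T / per_batch))
    decoded = []
    for i in range(0, T, per_batch):
        end_index = min(i + per_batch, T)
        chunk = samples[:, :, i:end_index, :, :]
        decoded.append(decode_fn(chunk))
        pbar.update(1)
    pbar.close()
    return decoded

# Usage with an identity "decoder" on dummy latents: 7 frames, 3 per chunk -> 3 steps.
parts = decode_in_chunks(lambda c: c, torch.randn(1, 12, 7, 8, 8), per_batch=3)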