Update nodes.py
This commit is contained in:
parent
d971a19410
commit
ebd0f62d53
5
nodes.py
5
nodes.py
@ -578,10 +578,9 @@ class MochiDecodeSpatialTiling:
|
|||||||
|
|
||||||
B, C, T, H, W = samples.shape
|
B, C, T, H, W = samples.shape
|
||||||
|
|
||||||
|
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
decoded_list = []
|
decoded_list = []
|
||||||
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype):
|
||||||
if enable_vae_tiling:
|
if enable_vae_tiling:
|
||||||
from .mochi_preview.vae.model import apply_tiled
|
from .mochi_preview.vae.model import apply_tiled
|
||||||
|
|
||||||
@ -595,7 +594,7 @@ class MochiDecodeSpatialTiling:
|
|||||||
frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
||||||
logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples")
|
logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples")
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
# Blend the first and last frames of each pair
|
# Blend the first and last frames of each pair
|
||||||
if len(decoded_list) > 0:
|
if len(decoded_list) > 0:
|
||||||
previous_frames = decoded_list[-1]
|
previous_frames = decoded_list[-1]
|
||||||
blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
|
blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user