diff --git a/nodes.py b/nodes.py index 76ac104..af6e74f 100644 --- a/nodes.py +++ b/nodes.py @@ -266,7 +266,9 @@ class MochiDecode: device = mm.get_torch_device() offload_device = mm.unet_offload_device() samples = samples["samples"] + samples = samples.to(torch.bfloat16).to(device) + def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for y in range(blend_extent):