diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py index 56f937e..e26add7 100644 --- a/mochi_preview/vae/model.py +++ b/mochi_preview/vae/model.py @@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d): raise NotImplementedError +def mps_safe_pad(input, pad, mode): + if input.device.type == "mps" and input.numel() >= 2 ** 16: + device = input.device + input = input.to(device="cpu") + output = F.pad(input, pad, mode=mode) + return output.to(device=device) + else: + return F.pad(input, pad, mode=mode) class ContextParallelConv3d(SafeConv3d): def __init__( @@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d): # Apply padding. mode = "constant" if self.padding_mode == "zeros" else self.padding_mode if self.context_parallel: - x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) + x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) else: - x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode) + x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode) return super().forward(x)