Work around pad problem on MPS

When using `torch.nn.functional.pad` with a tensor that has 2^16 (65536)
or more elements, the output tensor is broken.

This patch moves the tensor to the CPU to work around the problem.
It does not significantly impact the speed of the VAE on MPS.
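The symptom can be checked with a short script like the one below (a minimal sketch; only the 2^16 element threshold comes from this patch, while the tensor shape and the `replicate` padding mode are illustrative assumptions):

import torch
import torch.nn.functional as F

# Pad a tensor that crosses the 2**16 element threshold on MPS and
# compare against a CPU reference computed on the same data.
if torch.backends.mps.is_available():
    x = torch.randn(1, 1, 8, 128, 128, device="mps")  # 131072 elements
    pad = (0, 0, 0, 0, 2, 0)
    got = F.pad(x, pad, mode="replicate").cpu()
    want = F.pad(x.cpu(), pad, mode="replicate")
    print(torch.allclose(got, want))  # reportedly False before this patch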
Yoshimasa Niwa 2024-11-05 12:46:13 +09:00
parent 78f9e7b896
commit 5ca4bbf319


@@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
         raise NotImplementedError
 
 
+def mps_safe_pad(input, pad, mode):
+    if input.device.type == "mps" and input.numel() >= 2 ** 16:
+        device = input.device
+        input = input.to(device="cpu")
+        output = F.pad(input, pad, mode=mode)
+        return output.to(device=device)
+    else:
+        return F.pad(input, pad, mode=mode)
 
 
 class ContextParallelConv3d(SafeConv3d):
     def __init__(
@@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
         # Apply padding.
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
         if self.context_parallel:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
         else:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
         return super().forward(x)
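For reference, `F.pad` consumes the pad tuple from the last dimension backwards, two values per dimension, so the six-element tuple `(0, 0, 0, 0, pad_front, pad_back)` above touches only the third-from-last (temporal) dimension. A small sketch with an illustrative shape, not taken from the patch:

import torch
import torch.nn.functional as F

# Pad tuple order: (W_left, W_right, H_top, H_bottom, T_front, T_back)
x = torch.zeros(1, 3, 4, 16, 16)  # (N, C, T, H, W)
y = F.pad(x, (0, 0, 0, 0, 2, 1), mode="constant")
print(y.shape)  # torch.Size([1, 3, 7, 16, 16])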