Merge pull request #66 from niw/workaround_mps_pad_problem

Workaround pad problem on mps
This commit is contained in:
Jukka Seppänen 2024-11-05 13:48:45 +09:00 committed by GitHub
commit 21374934d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
raise NotImplementedError
def mps_safe_pad(input, pad, mode):
if input.device.type == "mps" and input.numel() >= 2 ** 16:
device = input.device
input = input.to(device="cpu")
output = F.pad(input, pad, mode=mode)
return output.to(device=device)
else:
return F.pad(input, pad, mode=mode)
class ContextParallelConv3d(SafeConv3d):
def __init__(
@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
# Apply padding.
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
if self.context_parallel:
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
else:
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
return super().forward(x)