Merge pull request #66 from niw/workaround_mps_pad_problem
Workaround pad problem on mps
This commit is contained in:
commit
21374934d3
@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def mps_safe_pad(input, pad, mode):
|
||||
if input.device.type == "mps" and input.numel() >= 2 ** 16:
|
||||
device = input.device
|
||||
input = input.to(device="cpu")
|
||||
output = F.pad(input, pad, mode=mode)
|
||||
return output.to(device=device)
|
||||
else:
|
||||
return F.pad(input, pad, mode=mode)
|
||||
|
||||
class ContextParallelConv3d(SafeConv3d):
|
||||
def __init__(
|
||||
@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
|
||||
# Apply padding.
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
if self.context_parallel:
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
else:
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user