Merge pull request #66 from niw/workaround_mps_pad_problem
Workaround pad problem on mps
This commit is contained in:
commit
21374934d3
@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
|
|||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def mps_safe_pad(input, pad, mode):
|
||||||
|
if input.device.type == "mps" and input.numel() >= 2 ** 16:
|
||||||
|
device = input.device
|
||||||
|
input = input.to(device="cpu")
|
||||||
|
output = F.pad(input, pad, mode=mode)
|
||||||
|
return output.to(device=device)
|
||||||
|
else:
|
||||||
|
return F.pad(input, pad, mode=mode)
|
||||||
|
|
||||||
class ContextParallelConv3d(SafeConv3d):
|
class ContextParallelConv3d(SafeConv3d):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
|
|||||||
# Apply padding.
|
# Apply padding.
|
||||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||||
if self.context_parallel:
|
if self.context_parallel:
|
||||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||||
else:
|
else:
|
||||||
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||||
|
|
||||||
|
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user