Work around pad problem on MPS
When `torch.nn.functional.pad` is used with a tensor whose element count is at least 2^16 (65536), the output tensor is corrupted. This patch moves the tensor to the CPU to work around the problem. It does not noticeably impact the speed of the VAE on MPS.
commit 5ca4bbf319
parent 78f9e7b896
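The failure can be reproduced with a plain F.pad call on an MPS tensor above the threshold. A minimal sketch, assuming an Apple Silicon machine with an affected PyTorch build; the tensor shape is illustrative, not taken from the patch:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 256, 256, device="mps")   # numel = 131072 >= 2**16
bad = F.pad(x, (1, 1, 1, 1), mode="replicate")  # output may be corrupted
# The workaround: pad on CPU, then move the result back to MPS.
good = F.pad(x.to("cpu"), (1, 1, 1, 1), mode="replicate").to("mps")
print(torch.equal(bad.cpu(), good.cpu()))       # False on affected builds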
@@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
         raise NotImplementedError


+def mps_safe_pad(input, pad, mode):
+    if input.device.type == "mps" and input.numel() >= 2 ** 16:
+        device = input.device
+        input = input.to(device="cpu")
+        output = F.pad(input, pad, mode=mode)
+        return output.to(device=device)
+    else:
+        return F.pad(input, pad, mode=mode)
+
+
 class ContextParallelConv3d(SafeConv3d):
     def __init__(
@@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
         # Apply padding.
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
         if self.context_parallel:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
         else:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)

         return super().forward(x)
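A quick sanity check of the new helper against a CPU reference. This is a sketch only: the import path is hypothetical (it mirrors the patched file rather than a published API) and the shapes are illustrative:

import torch
import torch.nn.functional as F
from cv_vae import mps_safe_pad  # hypothetical import path

x = torch.randn(1, 8, 16, 64, 64, device="mps")  # numel well above 2**16
pad = (0, 0, 0, 0, 2, 0)                         # pad the temporal dim, as in the patch
out = mps_safe_pad(x, pad, mode="replicate")     # pads on CPU, returns an MPS tensor
ref = F.pad(x.to("cpu"), pad, mode="replicate")  # CPU reference
assert out.device.type == "mps"
assert torch.allclose(out.cpu(), ref)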