Workaround pad problem on mps
When using `torch.nn.functional.pad` with a tensor whose number of elements is 2^16 (65,536) or larger, the output tensor is corrupted on MPS. This patch moves the tensor to the CPU to work around the problem. It does not noticeably impact the speed of the VAE on MPS.
This commit is contained in:
parent
78f9e7b896
commit
5ca4bbf319
@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
|
|||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def mps_safe_pad(input, pad, mode):
    """Pad ``input`` like ``F.pad``, working around an MPS backend bug.

    On the MPS device, ``torch.nn.functional.pad`` produces corrupted
    output for tensors with 2**16 or more elements, so such tensors are
    round-tripped through the CPU: padded there, then moved back to the
    original device. Small tensors and non-MPS tensors are padded directly.

    Args:
        input: Tensor to pad.
        pad:   Padding spec, as accepted by ``F.pad``.
        mode:  Padding mode, as accepted by ``F.pad``.

    Returns:
        The padded tensor, on the same device as ``input``.
    """
    needs_cpu_detour = input.device.type == "mps" and input.numel() >= 2 ** 16
    if not needs_cpu_detour:
        return F.pad(input, pad, mode=mode)
    # Remember where the tensor lives, pad on CPU, then move back.
    original_device = input.device
    padded = F.pad(input.to(device="cpu"), pad, mode=mode)
    return padded.to(device=original_device)
|
||||||
|
|
||||||
class ContextParallelConv3d(SafeConv3d):
|
class ContextParallelConv3d(SafeConv3d):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
|
|||||||
# Apply padding.
|
# Apply padding.
|
||||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||||
if self.context_parallel:
|
if self.context_parallel:
|
||||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||||
else:
|
else:
|
||||||
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||||
|
|
||||||
|
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user