[ROCm] Fallback pytorch GELU with tanh approximation to GELU() (#29244)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Signed-off-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent c0dfc89485
commit 4b40924998
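As background for the change described in the commit title, here is a minimal sketch (plain PyTorch only, not part of the commit) comparing the exact GELU that the ROCm path falls back to with the tanh approximation it replaces:

# Minimal sketch (not from the commit): compare exact vs. tanh-approximated GELU.
import torch
import torch.nn.functional as F

x = torch.randn(4, 256)

exact = F.gelu(x, approximate="none")   # erf-based GELU; the ROCm fallback
approx = F.gelu(x, approximate="tanh")  # tanh approximation; kept on other platforms

# On a healthy backend the two differ only by a small approximation error.
print("max abs difference:", (exact - approx).abs().max().item())

Per the warnings added below, the fallback trades this small approximation error for numerical stability under torch.compile on ROCm.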
@@ -159,6 +159,13 @@ class GeluAndMulSparse(CustomOp):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_rocm() and approximate == "tanh":
+            # TODO: [ROCm] PyTorch native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is currently "
+                "unstable and produces garbage. Falling back to 'none' approximation."
+            )
+            self.approximate = "none"
 
         # Sparsity.
         if activation_sparsity == 0.0:
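The hunk above guards the constructor with a warn-once fallback. Below is a self-contained sketch of the same pattern; platform_is_rocm() and warning_once() are illustrative stand-ins for vLLM's current_platform.is_rocm() and logger.warning_once():

# Standalone sketch of the warn-once constructor fallback (helper names are illustrative).
import logging

logger = logging.getLogger(__name__)
_seen: set[str] = set()

def warning_once(msg: str) -> None:
    # Rough stand-in for vLLM's logger.warning_once: emit each message only once.
    if msg not in _seen:
        _seen.add(msg)
        logger.warning(msg)

def platform_is_rocm() -> bool:
    # Placeholder for current_platform.is_rocm(); hard-coded here.
    return False

class GeluAndMulSparseLike:
    def __init__(self, approximate: str = "none"):
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate
        if platform_is_rocm() and self.approximate == "tanh":
            warning_once("tanh GELU is unstable on this platform; using 'none'.")
            self.approximate = "none"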
@@ -209,6 +216,12 @@ class GeluAndMul(CustomOp):
             self.op = torch.ops._C.gelu_and_mul
         elif approximate == "tanh":
             self.op = torch.ops._C.gelu_tanh_and_mul
+            if current_platform.is_rocm() and approximate == "tanh":
+                logger.warning_once(
+                    "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
+                    "with torch.compile. The native implementation falls back to the "
+                    "'none' approximation; the custom kernel implementation is unaffected."
+                )
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops
 
@@ -219,8 +232,12 @@ class GeluAndMul(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
+        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+        approximate = self.approximate
+        if current_platform.is_rocm() and approximate == "tanh":
+            approximate = "none"
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
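To make the forward_native change concrete, this is a small self-contained sketch of the split-gate-multiply computation it performs, with the same local override of approximate; the is_rocm flag stands in for current_platform.is_rocm():

# Sketch of the GELU-and-mul reference computation (not the vLLM class itself).
import torch
import torch.nn.functional as F

def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "none",
                           is_rocm: bool = False) -> torch.Tensor:
    # Override the mode locally, as the patched forward_native does, so the stored
    # configuration stays intact while the unstable tanh path is avoided.
    if is_rocm and approximate == "tanh":
        approximate = "none"
    d = x.shape[-1] // 2
    # GELU gates the first half of the last dimension; the second half scales it.
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

out = gelu_and_mul_reference(torch.randn(2, 8), approximate="tanh", is_rocm=True)
print(out.shape)  # torch.Size([2, 4])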
@@ -522,7 +539,16 @@ _ACTIVATION_REGISTRY = LazyDict(
         "gelu": lambda: nn.GELU(),
         "gelu_fast": lambda: FastGELU(),
         "gelu_new": lambda: NewGELU(),
-        "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
+        "gelu_pytorch_tanh": lambda: (
+            # TODO: [ROCm] PyTorch native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
+                "Falling back to GELU(approximate='none')."
+            ),
+            nn.GELU(approximate="none"),
+        )[1]
+        if current_platform.is_rocm()
+        else nn.GELU(approximate="tanh"),
         "relu": lambda: nn.ReLU(),
         "relu2": lambda: ReLUSquaredActivation(),
         "silu": lambda: nn.SiLU(),
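The new registry entry packs a side-effecting warning and the returned module into a tuple and keeps only index 1. A minimal reproduction of that idiom, with on_rocm standing in for current_platform.is_rocm() and warn_once for logger.warning_once:

# Sketch of the conditional warn-and-return lambda used in the registry (names illustrative).
import torch.nn as nn

def warn_once(msg: str) -> None:
    # Stand-in for logger.warning_once.
    print(f"WARNING: {msg}")

on_rocm = False  # stand-in for current_platform.is_rocm()

gelu_pytorch_tanh = lambda: (
    # The warning call is evaluated for its side effect; [1] keeps only the module.
    warn_once("tanh GELU is unstable here; using approximate='none'."),
    nn.GELU(approximate="none"),
)[1] if on_rocm else nn.GELU(approximate="tanh")

print(gelu_pytorch_tanh())  # GELU(approximate='tanh') while on_rocm is False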