Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 10:34:58 +08:00)
[ROCm] Fallback pytorch GELU with tanh approximation to GELU() (#29244)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Signed-off-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
c0dfc89485
commit
4b40924998
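For context on the change below: falling back from the tanh approximation to the exact (erf-based) GELU changes outputs only marginally. A minimal stand-alone comparison in plain PyTorch (illustrative only, not part of this commit):

# Compare exact vs. tanh-approximated GELU; the fallback swaps "tanh" for "none".
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)
exact = F.gelu(x, approximate="none")        # x * Phi(x), erf-based
tanh_approx = F.gelu(x, approximate="tanh")  # Hendrycks-Gimpel tanh formulation
print((exact - tanh_approx).abs().max())     # small gap, well below 1e-2 on this range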
@@ -159,6 +159,13 @@ class GeluAndMulSparse(CustomOp):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_rocm() and approximate == "tanh":
+            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
+                "unstable and produces garbage. Fallback to 'none' approximation."
+            )
+            self.approximate = "none"
 
         # Sparsity.
         if activation_sparsity == 0.0:
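The guard added to GeluAndMulSparse follows a simple pattern: validate the mode, then downgrade "tanh" to "none" on ROCm with a one-time warning. A stand-alone sketch of that pattern (the helper name resolve_gelu_approximate is hypothetical; is_rocm and warnings.warn stand in for current_platform.is_rocm() and logger.warning_once):

import warnings


def resolve_gelu_approximate(approximate: str, is_rocm: bool) -> str:
    # Mirrors the hunk above: reject unknown modes, downgrade "tanh" on ROCm.
    if approximate not in ("none", "tanh"):
        raise ValueError(f"Unknown approximate mode: {approximate}")
    if is_rocm and approximate == "tanh":
        warnings.warn(
            "[ROCm] PyTorch's native GELU with tanh approximation is unstable; "
            "falling back to the 'none' approximation."
        )
        return "none"
    return approximate


assert resolve_gelu_approximate("tanh", is_rocm=False) == "tanh"
assert resolve_gelu_approximate("tanh", is_rocm=True) == "none"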
@@ -209,6 +216,12 @@ class GeluAndMul(CustomOp):
                 self.op = torch.ops._C.gelu_and_mul
             elif approximate == "tanh":
                 self.op = torch.ops._C.gelu_tanh_and_mul
+                if current_platform.is_rocm() and approximate == "tanh":
+                    logger.warning_once(
+                        "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
+                        "with torch.compile. For native implementation, fallback to 'none' "
+                        "approximation. The custom kernel implementation is unaffected."
+                    )
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops
 
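Note that this second warning is informational only: the custom gelu_tanh_and_mul kernel still uses the tanh approximation, and only the PyTorch-native path falls back. Per the forward_native shown in the next hunk, the op is gated GELU over a tensor whose last dimension is split in half; a pure-PyTorch reference written under that assumption (not the fused kernel itself):

import torch
import torch.nn.functional as F


def gelu_and_mul_ref(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Reference semantics: GELU on the first half of the last dim, gated by the second half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


x = torch.randn(2, 8)
print(gelu_and_mul_ref(x).shape)  # torch.Size([2, 4]): the last dimension is halved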
@@ -219,8 +232,12 @@ class GeluAndMul(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
+        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+        approximate = self.approximate
+        if current_platform.is_rocm() and approximate == "tanh":
+            approximate = "none"
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
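Because the branch in forward_native is resolved on plain Python values rather than tensors, a torch.compile'd graph of the native path on ROCm should end up containing only the erf-based GELU. A stand-alone reduction of that behaviour (names mirror the hunk, but this is not vLLM's class; torch.version.hip is used as a rough stand-in for current_platform.is_rocm()):

import torch
import torch.nn.functional as F

IS_ROCM = torch.version.hip is not None  # rough stand-in for current_platform.is_rocm()


def forward_native(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    if IS_ROCM and approximate == "tanh":
        approximate = "none"  # resolved on Python values, not traced tensor ops
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


compiled = torch.compile(forward_native)
print(compiled(torch.randn(4, 16)).shape)  # torch.Size([4, 8])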
@@ -522,7 +539,16 @@ _ACTIVATION_REGISTRY = LazyDict(
         "gelu": lambda: nn.GELU(),
         "gelu_fast": lambda: FastGELU(),
         "gelu_new": lambda: NewGELU(),
-        "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
+        "gelu_pytorch_tanh": lambda: (
+            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
+                "Falling back to GELU(approximate='none')."
+            ),
+            nn.GELU(approximate="none"),
+        )[1]
+        if current_platform.is_rocm()
+        else nn.GELU(approximate="tanh"),
         "relu": lambda: nn.ReLU(),
         "relu2": lambda: ReLUSquaredActivation(),
         "silu": lambda: nn.SiLU(),
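The registry entry relies on a small Python idiom to run a side effect inside a lambda: build a tuple whose first element is the warning call and whose second is the value, then take element [1]. A stand-alone illustration (print replaces logger.warning_once, which in the real code additionally deduplicates repeated warnings):

import torch.nn as nn

gelu_pytorch_tanh_rocm = lambda: (
    print("warning: falling back to GELU(approximate='none')"),  # side effect, evaluates to None
    nn.GELU(approximate="none"),  # the value the lambda actually returns
)[1]

act = gelu_pytorch_tanh_rocm()
print(type(act).__name__)  # GELU
print(act.approximate)     # 'none'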