diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 3471ee327cf8..7038d0868c7e 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -159,6 +159,13 @@ class GeluAndMulSparse(CustomOp):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_rocm() and approximate == "tanh":
+            # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is currently "
+                "unstable and produces incorrect results. Falling back to 'none'."
+            )
+            self.approximate = "none"
 
         # Sparsity.
         if activation_sparsity == 0.0:
@@ -209,6 +216,12 @@ class GeluAndMul(CustomOp):
                 self.op = torch.ops._C.gelu_and_mul
             elif approximate == "tanh":
                 self.op = torch.ops._C.gelu_tanh_and_mul
+            if current_platform.is_rocm() and approximate == "tanh":
+                logger.warning_once(
+                    "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
+                    "with torch.compile. The native implementation falls back to the "
+                    "'none' approximation; the custom kernel is unaffected."
+                )
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops
 
@@ -219,8 +232,12 @@ class GeluAndMul(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
+        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+        approximate = self.approximate
+        if current_platform.is_rocm() and approximate == "tanh":
+            approximate = "none"
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -522,7 +539,16 @@ _ACTIVATION_REGISTRY = LazyDict(
         "gelu": lambda: nn.GELU(),
         "gelu_fast": lambda: FastGELU(),
         "gelu_new": lambda: NewGELU(),
-        "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
+        "gelu_pytorch_tanh": lambda: (
+            # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
+            logger.warning_once(
+                "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
+                "Falling back to GELU(approximate='none')."
+            ),
+            nn.GELU(approximate="none"),
+        )[1]
+        if current_platform.is_rocm()
+        else nn.GELU(approximate="tanh"),
         "relu": lambda: nn.ReLU(),
         "relu2": lambda: ReLUSquaredActivation(),
         "silu": lambda: nn.SiLU(),
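
Note (not part of the patch): a minimal standalone sketch of what the `forward_native` fallback computes on ROCm, assuming only `torch`. The `gelu_and_mul_reference` helper below is hypothetical and simply mirrors the gated-GELU split in `GeluAndMul.forward_native`; running it reports how far the "none" approximation drifts from "tanh", which is the numerical trade-off the fallback accepts until the torch.compile instability is resolved.

```python
# Hypothetical check (not part of the patch): compares the gated GELU with the
# "tanh" approximation against the "none" approximation that the ROCm fallback
# uses in forward_native().
import torch
import torch.nn.functional as F


def gelu_and_mul_reference(x: torch.Tensor, approximate: str) -> torch.Tensor:
    # Mirrors GeluAndMul.forward_native(): GELU-gate the first half of the
    # last dimension and multiply by the second half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 2048, dtype=torch.float32)
    out_tanh = gelu_and_mul_reference(x, "tanh")
    out_none = gelu_and_mul_reference(x, "none")
    # The two approximations differ only slightly, which is what makes the
    # fallback a reasonable stopgap.
    print("max abs diff:", (out_tanh - out_none).abs().max().item())
```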