diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 22b10f0571d1..24d972702c85 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -33,12 +33,25 @@ template<typename T>
 __device__ __forceinline__ T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
   // Refer to:
-  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
   const float f = (float) x;
   constexpr float ALPHA = M_SQRT1_2;
   return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }
 
+template<typename T>
+__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'tanh' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
+  const float f = (float) x;
+  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
+  constexpr float KAPPA = 0.044715;
+  float x_cube = f * f * f;
+  float inner = BETA * (f + KAPPA * x_cube);
+  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+}
+
 } // namespace vllm
 
 // Launch activation and gating kernel.
@@ -73,6 +86,13 @@ void gelu_and_mul(
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
+void gelu_tanh_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+}
+
 namespace vllm {
 
 // Element-wise activation kernel template.
diff --git a/csrc/ops.h b/csrc/ops.h
index 249c7451bf73..53222972abb7 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -61,6 +61,10 @@ void gelu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_tanh_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index 4b6ade756639..39384f08d928 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def(
     "gelu_and_mul",
     &gelu_and_mul,
-    "Activation function used in GeGLU.");
+    "Activation function used in GeGLU with `none` approximation.");
+  ops.def(
+    "gelu_tanh_and_mul",
+    &gelu_tanh_and_mul,
+    "Activation function used in GeGLU with `tanh` approximation.");
   ops.def(
     "gelu_new",
     &gelu_new,
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index e0dec144eba1..f78913f120aa 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -16,7 +16,7 @@ CUDA_DEVICES = [
 ]
 
 
-@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
+@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -24,7 +24,7 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_act_and_mul(
-    activation: Type[torch.nn.Module],
+    activation: str,
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -36,7 +36,12 @@ def test_act_and_mul(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = activation()
+    if activation == "silu":
+        layer = SiluAndMul()
+    elif activation == "gelu":
+        layer = GeluAndMul(approximate="none")
+    elif activation == "gelu_tanh":
+        layer = GeluAndMul(approximate="tanh")
     out = layer(x)
     ref_out = layer._forward(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 5a3a7b2dbaee..3eb73ee109f5 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """
 
+    def __init__(self, approximate: str = "none"):
+        super().__init__()
+        self.approximate = approximate
+        if approximate not in ("none", "tanh"):
+            raise ValueError(f"Unknown approximate mode: {approximate}")
+
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d]) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.gelu_and_mul(out, x)
+        if self.approximate == "none":
+            ops.gelu_and_mul(out, x)
+        elif self.approximate == "tanh":
+            ops.gelu_tanh_and_mul(out, x)
         return out
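
For context, the sketch below shows how the new `tanh` mode added by this patch can be exercised from Python, mirroring the updated test. It is illustrative rather than part of the change itself: the tensor shape, dtype, and tolerances are placeholders, and it assumes a CUDA device and a vLLM build that includes this kernel.

    # Illustrative sketch (not part of the patch): compare the CUDA
    # gelu_tanh_and_mul path against the PyTorch-native fallback.
    import torch
    from vllm.model_executor.layers.activation import GeluAndMul

    torch.set_default_device("cuda")
    d = 512
    x = torch.randn(8, 2 * d, dtype=torch.float16)  # input is [..., 2 * d]

    layer = GeluAndMul(approximate="tanh")
    out = layer(x)           # dispatches to ops.gelu_tanh_and_mul(out, x)
    ref = layer._forward(x)  # F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

    assert out.shape == (8, d)
    # Placeholder tolerances for the half-precision comparison.
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)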