diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 55e6596797010..a4a880f13cf7e 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
+                                               float alpha, float limit) {
+  // clamp gate: min=None, max=limit
+  const float gate_f = (float)gate;
+  const float clamped_gate = gate_f > limit ? limit : gate_f;
+
+  // clamp up: min=-limit, max=limit
+  const float up_f = (float)up;
+  const float clamped_up =
+      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
+
+  // glu = gate * sigmoid(gate * alpha)
+  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
+  const float glu = clamped_gate * sigmoid_val;
+
+  // (up + 1) * glu
+  return (T)((clamped_up + 1.0f) * glu);
+}
+
+template <typename scalar_t,
+          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
+                             const float)>
+__global__ void swigluoai_and_mul_kernel(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d, const float alpha, const float limit) {
+  const int64_t token_idx = blockIdx.x;
+  // TODO: Vectorize loads and stores.
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    // gate = x[..., ::2] (even indices)
+    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
+    // up = x[..., 1::2] (odd indices)
+    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
+
+    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
+  }
+}
+
 }  // namespace vllm
 
 #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
       PARAM);                                                          \
   });
 
+#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                  \
+  int d = input.size(-1) / 2;                                          \
+  int64_t num_tokens = input.numel() / input.size(-1);                 \
+  dim3 grid(num_tokens);                                               \
+  dim3 block(std::min(d, 1024));                                       \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));    \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();        \
+  VLLM_DISPATCH_FLOATING_TYPES(                                        \
+      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {    \
+        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>     \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),     \
+                                         input.data_ptr<scalar_t>(),   \
+                                         d, ALPHA, LIMIT);             \
+      });
+
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
+void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input,  // [..., 2 * d]
+                       double alpha, double limit) {
+  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
+}
 
 namespace vllm {
 
 // Element-wise activation kernel template.
diff --git a/csrc/ops.h b/csrc/ops.h
index 207291eceb169..8b41b95473a16 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -138,6 +138,8 @@
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                      double threshold);
+void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
+                       double alpha = 1.702, double limit = 7.0);
 
 void gelu_new(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 8c207be083d88..41e9bc8a5e010 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -130,6 +130,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
   ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
 
+  ops.def(
+      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, "
+      "float limit=7.0) -> ()");
+  ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
+
   // GELU implementation used in GPT-2.
   ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_new", torch::kCUDA, &gelu_new);
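A minimal sanity check for the newly registered op (not part of the patch): it assumes the `_C` extension is built with this change and a CUDA device is available, and the tolerances are illustrative only. The eager reference mirrors the kernel math, with gate values at even indices and up values at odd indices of the last dimension.

```python
import torch

torch.manual_seed(0)
num_tokens, d = 4, 128
alpha, limit = 1.702, 7.0

x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")

# Eager reference mirroring the kernel: gate at even indices, up at odd ones.
gate, up = x[..., ::2].float(), x[..., 1::2].float()
gate = gate.clamp(max=limit)
up = up.clamp(min=-limit, max=limit)
ref = ((up + 1) * (gate * torch.sigmoid(gate * alpha))).to(x.dtype)

out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
torch.ops._C.swigluoai_and_mul(out, x, alpha, limit)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=2e-3)
```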
diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py
index 29c5e70a8ba85..ec5c60fd7b0e2 100644
--- a/tests/kernels/core/test_activation.py
+++ b/tests/kernels/core/test_activation.py
@@ -11,7 +11,7 @@
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                    GeluAndMul, MulAndSilu,
                                                    NewGELU, QuickGELU,
-                                                   SiluAndMul)
+                                                   SiluAndMul, SwigluOAIAndMul)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,7 +25,15 @@ CUDA_DEVICES = [
 
 @pytest.mark.parametrize(
     "activation",
-    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
+    [
+        "silu_and_mul",
+        "mul_and_silu",
+        "gelu",
+        "gelu_tanh",
+        "fatrelu",
+        "swigluoai_and_mul",
+    ],
+)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -59,18 +67,43 @@ def test_act_and_mul(
         threshold = random.uniform(0, 1)
         layer = FatreluAndMul(threshold)
         fn = torch.ops._C.fatrelu_and_mul
+    elif activation == "swigluoai_and_mul":
+        layer = SwigluOAIAndMul()
+        fn = torch.ops._C.swigluoai_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
-    # equivalent to the native PyTorch implementations, so we can do exact
-    # comparison.
-    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
+    if activation == "swigluoai_and_mul":
+
+        rtol = {
+            # For fp16, relax the relative tolerance from 1e-3 to 2e-3.
+            torch.float16:
+            2e-3,
+            torch.bfloat16:
+            2e-2,
+            torch.float:
+            1.3e-6
+        }
+
+        def _get_rtol(output) -> float:
+            return rtol[output.dtype]
+
+        torch.testing.assert_close(out,
+                                   ref_out,
+                                   atol=get_default_atol(out),
+                                   rtol=_get_rtol(out))
+    else:
+        # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+        # equivalent to the native PyTorch implementations, so we can do exact
+        # comparison.
+        torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
     if activation == "fatrelu":
         opcheck(fn, (out, x, threshold))
+    elif activation == "swigluoai_and_mul":
+        opcheck(fn, (out, x, layer.alpha, layer.limit))
     else:
         opcheck(fn, (out, x))
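The relaxed tolerances above are likely needed because the CUDA kernel converts to fp32, computes the whole expression, and casts once at the end, while `forward_native` stays in the input dtype throughout, so fp16/bf16 outputs can differ by a few ULPs. A standalone sketch of that effect (not part of the patch; values are illustrative):

```python
import torch

g = torch.tensor(1.7, dtype=torch.float16)
u = torch.tensor(0.3, dtype=torch.float16)
alpha = 1.702

# fp16 arithmetic at every step (like forward_native on a half tensor).
half_path = (u + 1) * (g * torch.sigmoid(g * alpha))

# fp32 arithmetic with a single final cast (like the CUDA kernel).
float_path = ((u.float() + 1) *
              (g.float() * torch.sigmoid(g.float() * alpha))).half()

print(half_path.item(), float_path.item())  # may differ slightly
```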
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 7ce44174ead6d..5f89dadec8b83 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -239,6 +239,35 @@ class GeluAndMul(CustomOp):
         return f'approximate={repr(self.approximate)}'
 
 
+@CustomOp.register("swigluoai_and_mul")
+class SwigluOAIAndMul(CustomOp):
+    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
+    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
+        super().__init__()
+        self.alpha = alpha
+        self.limit = limit
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+
+        gate, up = x[..., ::2], x[..., 1::2]
+        gate = gate.clamp(min=None, max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        gated_output = (up + 1) * glu
+        return gated_output
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
+        return out
+
+    def extra_repr(self) -> str:
+        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
+
+
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
@@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp):
         return torch.square(F.relu(x))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: implement a CUDA kernel for this activation.
         return self.forward_native(x)
 
 
@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
-    "gelu": lambda: GeluAndMul(),
-    "silu": lambda: SiluAndMul(),
-    "geglu": lambda: GeluAndMul(),
+    "gelu":
+    lambda: GeluAndMul(),
+    "silu":
+    lambda: SiluAndMul(),
+    "geglu":
+    lambda: GeluAndMul(),
+    "swigluoai_and_mul":
+    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
 })
 
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 98087a35e15c7..23ebad36daf2b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1633,17 +1633,6 @@ def fused_experts_impl(
                             block_shape=block_shape,
                             B_bias=w1_bias)
 
-        # TODO fused kernel
-        def swiglu_oai(gate_up):
-            alpha = 1.702
-            limit = 7.0
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=limit)
-            up = up.clamp(min=-limit, max=limit)
-            glu = gate * torch.sigmoid(gate * alpha)
-            gated_output = (up + 1) * glu
-            return gated_output
-
         # Activation function with multiplication
         if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1651,13 +1640,16 @@ def fused_experts_impl(
         elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        elif activation == "swigluoai" and is_act_and_mul:
+            # alpha = 1.702, limit = 7.0
+            torch.ops._C.swigluoai_and_mul(intermediate_cache2,
+                                           intermediate_cache1.view(-1, N))
         # Activation function without multiplication
         elif activation == "silu":
             intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
         elif activation == "gelu":
             intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-        elif activation == "swiglu_oai":
-            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
+
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                              f"with is_act_and_mul={is_act_and_mul}.")
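Note that `swigluoai_and_mul` consumes gate and up values interleaved along the last dimension (`x[..., ::2]` / `x[..., 1::2]`), unlike `silu_and_mul`, which splits the last dimension into two contiguous halves, so `intermediate_cache1` must already be laid out this way. A small illustration of that layout (not part of the patch; the tensors are made up for the example):

```python
import torch

d = 4
gate = torch.arange(0, d, dtype=torch.float32)       # hypothetical gate values
up = torch.arange(10, 10 + d, dtype=torch.float32)   # hypothetical up values

# Interleave: [g0, u0, g1, u1, ...] along the last dimension.
x_interleaved = torch.stack([gate, up], dim=-1).reshape(-1)
assert torch.equal(x_interleaved[::2], gate)
assert torch.equal(x_interleaved[1::2], up)
```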
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index deeb69bcad0ec..dca38a019e9b1 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -68,7 +68,7 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
     return not (use_grouped_topk or topk_group or num_expert_group
                 or expert_map or custom_routing_function
                 or e_score_correction_bias or apply_router_weight_on_input
-                or scoring_func != "softmax" or activation != "swiglu_oai"
+                or scoring_func != "softmax" or activation != "swigluoai"
                 or expert_load_view or logical_to_physical_map
                 or logical_replica_count)
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 7c7712dbe106e..2f5d9ddd9054f 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -159,7 +159,7 @@ class MLPBlock(torch.nn.Module):
             prefix=f"{prefix}.experts",
             apply_router_weight_on_input=False,
             has_bias=True,
-            activation="swiglu_oai")
+            activation="swigluoai")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         t = self.norm(x)
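End to end, the new layer can also be exercised directly. A minimal usage sketch (not part of the patch; it calls `forward_native` so it runs even without the compiled `_C` extension, and the tensor shape is arbitrary):

```python
import torch

from vllm.model_executor.layers.activation import SwigluOAIAndMul

layer = SwigluOAIAndMul()   # defaults: alpha=1.702, limit=7.0
x = torch.randn(2, 16)      # last dim is 2 * d, with gate/up interleaved
out = layer.forward_native(x)
print(out.shape)            # torch.Size([2, 8])
print(layer.extra_repr())   # alpha=1.702, limit=7.0
```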