diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index a4a880f13cf7e..55e6596797010 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -128,45 +128,6 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }
 
-template <typename T>
-__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
-                                               float alpha, float limit) {
-  // clamp gate: min=None, max=limit
-  const float gate_f = (float)gate;
-  const float clamped_gate = gate_f > limit ? limit : gate_f;
-
-  // clamp up: min=-limit, max=limit
-  const float up_f = (float)up;
-  const float clamped_up =
-      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
-
-  // glu = gate * sigmoid(gate * alpha)
-  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
-  const float glu = clamped_gate * sigmoid_val;
-
-  // (up + 1) * glu
-  return (T)((clamped_up + 1.0f) * glu);
-}
-
-template <typename scalar_t,
-          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
-                             const float)>
-__global__ void swigluoai_and_mul_kernel(
-    scalar_t* __restrict__ out,          // [..., d]
-    const scalar_t* __restrict__ input,  // [..., 2, d]
-    const int d, const float alpha, const float limit) {
-  const int64_t token_idx = blockIdx.x;
-  // TODO: Vectorize loads and stores.
-  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    // gate = x[..., ::2] (even indices)
-    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
-    // up = x[..., 1::2] (odd indices)
-    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
-
-    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
-  }
-}
-
 }  // namespace vllm
 
 #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)             \
@@ -184,31 +145,11 @@ __global__ void swigluoai_and_mul_kernel(
                                   PARAM);                                   \
       });
 
-#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                       \
-  int d = input.size(-1) / 2;                                               \
-  int64_t num_tokens = input.numel() / input.size(-1);                      \
-  dim3 grid(num_tokens);                                                    \
-  dim3 block(std::min(d, 1024));                                            \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));         \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();             \
-  VLLM_DISPATCH_FLOATING_TYPES(                                             \
-      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {         \
-        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>          \
-            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),          \
-                                         input.data_ptr<scalar_t>(), d, ALPHA, \
-                                         LIMIT);                            \
-      });
-
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
-void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
-                       torch::Tensor& input,  // [..., 2 * d]
-                       double alpha, double limit) {
-  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
-}
 
 namespace vllm {
 // Element-wise activation kernel template.
diff --git a/csrc/ops.h b/csrc/ops.h
index 8b41b95473a16..207291eceb169 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -138,8 +138,6 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                      double threshold);
-void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
-                       double alpha = 1.702, double limit = 7.0);
 
 void gelu_new(torch::Tensor& out, torch::Tensor& input);
 
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 41e9bc8a5e010..8c207be083d88 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -130,11 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
   ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
 
-  ops.def(
-      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha, float limit) "
-      "-> ()");
-  ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
-
   // GELU implementation used in GPT-2.
   ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_new", torch::kCUDA, &gelu_new);
diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py
index ec5c60fd7b0e2..29c5e70a8ba85 100644
--- a/tests/kernels/core/test_activation.py
+++ b/tests/kernels/core/test_activation.py
@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                    GeluAndMul, MulAndSilu,
                                                    NewGELU, QuickGELU,
-                                                   SiluAndMul, SwigluOAIAndMul)
+                                                   SiluAndMul)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,15 +25,7 @@ CUDA_DEVICES = [
 
 @pytest.mark.parametrize(
     "activation",
-    [
-        "silu_and_mul",
-        "mul_and_silu",
-        "gelu",
-        "gelu_tanh",
-        "fatrelu",
-        "swigluoai_and_mul",
-    ],
-)
+    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -67,43 +59,18 @@ def test_act_and_mul(
         threshold = random.uniform(0, 1)
         layer = FatreluAndMul(threshold)
         fn = torch.ops._C.fatrelu_and_mul
-    elif activation == "swigluoai_and_mul":
-        layer = SwigluOAIAndMul()
-        fn = torch.ops._C.swigluoai_and_mul
 
     out = layer(x)
     ref_out = layer.forward_native(x)
-    if activation == "swigluoai_and_mul":
-
-        rtol = {
-            #For fp16, change the relative tolerance from 1e-3 to 2e-3
-            torch.float16:
-            2e-3,
-            torch.bfloat16:
-            2e-2,
-            torch.float:
-            1.3e-6
-        }
-
-        def _get_rtol(output) -> float:
-            return rtol[output.dtype]
-
-        torch.testing.assert_close(out,
-                                   ref_out,
-                                   atol=get_default_atol(out),
-                                   rtol=_get_rtol(out))
-    else:
-        # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
-        # equivalent to the native PyTorch implementations, so we can do exact
-        # comparison.
-        torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
+    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+    # equivalent to the native PyTorch implementations, so we can do exact
+    # comparison.
+    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
     if activation == "fatrelu":
         opcheck(fn, (out, x, threshold))
-    elif activation == "swigluoai_and_mul":
-        opcheck(fn, (out, x, layer.alpha, layer.limit))
     else:
         opcheck(fn, (out, x))
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 5f89dadec8b83..7ce44174ead6d 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -239,35 +239,6 @@ class GeluAndMul(CustomOp):
         return f'approximate={repr(self.approximate)}'
 
 
-@CustomOp.register("swigluoai_and_mul")
-class SwigluOAIAndMul(CustomOp):
-    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
-    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
-        super().__init__()
-        self.alpha = alpha
-        self.limit = limit
-
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
-
-        gate, up = x[..., ::2], x[..., 1::2]
-        gate = gate.clamp(min=None, max=self.limit)
-        up = up.clamp(min=-self.limit, max=self.limit)
-        glu = gate * torch.sigmoid(gate * self.alpha)
-        gated_output = (up + 1) * glu
-        return gated_output
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
-        return out
-
-    def extra_repr(self) -> str:
-        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
-
-
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
@@ -359,7 +330,6 @@ class ReLUSquaredActivation(CustomOp):
         return torch.square(F.relu(x))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        #TODO : implement cuda kenrels
         return self.forward_native(x)
 
 
@@ -436,14 +406,9 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
-    "gelu":
-    lambda: GeluAndMul(),
-    "silu":
-    lambda: SiluAndMul(),
-    "geglu":
-    lambda: GeluAndMul(),
-    "swigluoai_and_mul":
-    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
+    "gelu": lambda: GeluAndMul(),
+    "silu": lambda: SiluAndMul(),
+    "geglu": lambda: GeluAndMul(),
 })
 
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 23ebad36daf2b..98087a35e15c7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1633,6 +1633,17 @@ def fused_experts_impl(
                                 block_shape=block_shape,
                                 B_bias=w1_bias)
 
+        # TODO fused kernel
+        def swiglu_oai(gate_up):
+            alpha = 1.702
+            limit = 7.0
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=limit)
+            up = up.clamp(min=-limit, max=limit)
+            glu = gate * torch.sigmoid(gate * alpha)
+            gated_output = (up + 1) * glu
+            return gated_output
+
         # Activation function with multiplication
         if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1640,16 +1651,13 @@
         elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
-        elif activation == "swigluoai" and is_act_and_mul:
-            # alpha = 1.702, limit = 7.0
-            torch.ops._C.swigluoai_and_mul(intermediate_cache2,
-                                           intermediate_cache1.view(-1, N))
         # Activation function without multiplication
         elif activation == "silu":
             intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
         elif activation == "gelu":
             intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-
+        elif activation == "swiglu_oai":
+            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                              f"with is_act_and_mul={is_act_and_mul}.")
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index dca38a019e9b1..deeb69bcad0ec 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -68,7 +68,7 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
     return not (use_grouped_topk or topk_group or num_expert_group
                 or expert_map or custom_routing_function
                 or e_score_correction_bias or apply_router_weight_on_input
-                or scoring_func != "softmax" or activation != "swigluoai"
+                or scoring_func != "softmax" or activation != "swiglu_oai"
                 or expert_load_view or logical_to_physical_map
                 or logical_replica_count)
 
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 2f5d9ddd9054f..7c7712dbe106e 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -159,7 +159,7 @@ class MLPBlock(torch.nn.Module):
             prefix=f"{prefix}.experts",
             apply_router_weight_on_input=False,
             has_bias=True,
-            activation="swigluoai")
+            activation="swiglu_oai")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         t = self.norm(x)
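
For reference, the Python path that replaces the removed CUDA kernel computes SwiGLU with the GPT-OSS clamping scheme (alpha=1.702, limit=7.0 in the patch). Below is a minimal standalone sketch of that activation; the name swiglu_oai_ref and the shape check are illustrative only and not part of vLLM's API. The in-tree version is the inline swiglu_oai closure added to fused_experts_impl above, and the same formula appears in the removed SwigluOAIAndMul.forward_native, so this sketch can double as a parity reference if a fused kernel is reintroduced later.

# Standalone sketch (not vLLM code): reference "swiglu_oai" activation with
# the interleaved gate/up layout used by GPT-OSS experts.
import torch


def swiglu_oai_ref(gate_up: torch.Tensor,
                   alpha: float = 1.702,
                   limit: float = 7.0) -> torch.Tensor:
    # gate = even columns, up = odd columns of the [..., 2 * d] input
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)              # clamp gate from above only
    up = up.clamp(min=-limit, max=limit)      # clamp up to [-limit, limit]
    glu = gate * torch.sigmoid(gate * alpha)  # SiLU-style gating scaled by alpha
    return (up + 1) * glu                     # (up + 1) * glu, as in the patch


if __name__ == "__main__":
    x = torch.randn(4, 16)                    # 4 tokens, d = 8
    assert swiglu_oai_ref(x).shape == (4, 8)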