From c4768dcf47ae919257e31b49a03c00d383ba3c55 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Mon, 10 Nov 2025 05:26:35 +0800 Subject: [PATCH] [Kernel] Fix fused_gdn_gating (#28343) Signed-off-by: zjy0516 --- vllm/model_executor/models/qwen3_next.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 880655da3f0a5..55bbad7a8b275 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel( blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) # compute beta_output = sigmoid(b) - blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32))) - tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store( + beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask + ) def fused_gdn_gating( @@ -1389,7 +1391,7 @@ def fused_gdn_gating( seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) - beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) + beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) fused_gdn_gating_kernel[grid]( g, beta_output,