[Kernel] Fix fused_gdn_gating (#28343)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu 2025-11-10 05:26:35 +08:00 committed by GitHub
parent a65a934ebe
commit c4768dcf47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel(
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
# compute beta_output = sigmoid(b)
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
tl.store(
beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
)
def fused_gdn_gating(
@ -1389,7 +1391,7 @@ def fused_gdn_gating(
seq_len = 1
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
fused_gdn_gating_kernel[grid](
g,
beta_output,