[Kernel] Fuse computation of g and beta for Gated Delta Net (#28095)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu 2025-11-06 01:14:55 +08:00 committed by GitHub
parent 6fd0df8132
commit c18f88c6ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -551,10 +551,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec
)
beta = b.sigmoid()
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
@ -1289,12 +1286,13 @@ direct_register_custom_op(
)
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit
def fused_gdn_gating_kernel(
g,
beta_output,
A_log,
a,
b,
dt_bias,
seq_len,
NUM_HEADS: tl.constexpr,
@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
mask = head_off < NUM_HEADS
blk_A_log = tl.load(A_log + head_off, mask=mask)
blk_a = tl.load(a + off, mask=mask)
blk_b = tl.load(b + off, mask=mask)
blk_bias = tl.load(dt_bias + head_off, mask=mask)
# If the model is loaded in fp16, without the .float() here, A might be -inf
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
)
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
# compute beta_output = sigmoid(b)
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fused computation of g and beta for Gated Delta Net.
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
beta_output = b.sigmoid()
TODO maybe use torch.compile to replace this triton kernel
"""
batch, num_heads = a.shape
seq_len = 1
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
g = torch.empty_like(a, dtype=torch.float32)
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
fused_gdn_gating_kernel[grid](
g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
g,
beta_output,
A_log,
a,
b,
dt_bias,
seq_len,
num_heads,
beta,
threshold,
8,
num_warps=1,
)
return g
return g, beta_output