[Kernel] Fuse computation of g and beta for Gated Delta Net (#28095)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu 2025-11-06 01:14:55 +08:00 committed by GitHub
parent 6fd0df8132
commit c18f88c6ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -551,10 +551,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec mixed_qkv_non_spec
) )
beta = b.sigmoid() g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
@ -1289,12 +1286,13 @@ direct_register_custom_op(
) )
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit @triton.jit
def fused_gdn_gating_kernel( def fused_gdn_gating_kernel(
g, g,
beta_output,
A_log, A_log,
a, a,
b,
dt_bias, dt_bias,
seq_len, seq_len,
NUM_HEADS: tl.constexpr, NUM_HEADS: tl.constexpr,
@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
mask = head_off < NUM_HEADS mask = head_off < NUM_HEADS
blk_A_log = tl.load(A_log + head_off, mask=mask) blk_A_log = tl.load(A_log + head_off, mask=mask)
blk_a = tl.load(a + off, mask=mask) blk_a = tl.load(a + off, mask=mask)
blk_b = tl.load(b + off, mask=mask)
blk_bias = tl.load(dt_bias + head_off, mask=mask) blk_bias = tl.load(dt_bias + head_off, mask=mask)
# If the model is loaded in fp16, without the .float() here, A might be -inf # If the model is loaded in fp16, without the .float() here, A might be -inf
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
) )
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
# compute beta_output = sigmoid(b)
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
def fused_gdn_gating( def fused_gdn_gating(
A_log: torch.Tensor, A_log: torch.Tensor,
a: torch.Tensor, a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor, dt_bias: torch.Tensor,
beta: float = 1.0, beta: float = 1.0,
threshold: float = 20.0, threshold: float = 20.0,
) -> torch.Tensor: ) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fused computation of g and beta for Gated Delta Net.
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
beta_output = b.sigmoid()
TODO maybe use torch.compile to replace this triton kernel
"""
batch, num_heads = a.shape batch, num_heads = a.shape
seq_len = 1 seq_len = 1
grid = (batch, seq_len, triton.cdiv(num_heads, 8)) grid = (batch, seq_len, triton.cdiv(num_heads, 8))
g = torch.empty_like(a, dtype=torch.float32) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
fused_gdn_gating_kernel[grid]( fused_gdn_gating_kernel[grid](
g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 g,
beta_output,
A_log,
a,
b,
dt_bias,
seq_len,
num_heads,
beta,
threshold,
8,
num_warps=1,
) )
return g return g, beta_output