mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 23:54:26 +08:00
[Kernel] Fuse computation of g and beta for Gated Delta Net (#28095)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
6fd0df8132
commit
c18f88c6ca
@ -551,10 +551,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
mixed_qkv_non_spec
|
||||
)
|
||||
|
||||
beta = b.sigmoid()
|
||||
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
||||
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
|
||||
g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
@ -1289,12 +1286,13 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
||||
@triton.jit
|
||||
def fused_gdn_gating_kernel(
|
||||
g,
|
||||
beta_output,
|
||||
A_log,
|
||||
a,
|
||||
b,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
NUM_HEADS: tl.constexpr,
|
||||
@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
|
||||
mask = head_off < NUM_HEADS
|
||||
blk_A_log = tl.load(A_log + head_off, mask=mask)
|
||||
blk_a = tl.load(a + off, mask=mask)
|
||||
blk_b = tl.load(b + off, mask=mask)
|
||||
blk_bias = tl.load(dt_bias + head_off, mask=mask)
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
|
||||
@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
|
||||
)
|
||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||
# compute beta_output = sigmoid(b)
|
||||
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
|
||||
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def fused_gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Fused computation of g and beta for Gated Delta Net.
|
||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
||||
beta_output = b.sigmoid()
|
||||
TODO maybe use torch.compile to replace this triton kernel
|
||||
"""
|
||||
batch, num_heads = a.shape
|
||||
seq_len = 1
|
||||
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
|
||||
g = torch.empty_like(a, dtype=torch.float32)
|
||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
|
||||
fused_gdn_gating_kernel[grid](
|
||||
g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
|
||||
g,
|
||||
beta_output,
|
||||
A_log,
|
||||
a,
|
||||
b,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
num_heads,
|
||||
beta,
|
||||
threshold,
|
||||
8,
|
||||
num_warps=1,
|
||||
)
|
||||
return g
|
||||
return g, beta_output
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user