mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Kernel] Fuse computation of g and beta for Gated Delta Net (#28095)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent 6fd0df8132
commit c18f88c6ca
@@ -551,10 +551,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             mixed_qkv_non_spec
         )
 
-        beta = b.sigmoid()
-        # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-        g = fused_gdn_gating(self.A_log, a, self.dt_bias)
-        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
+        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
 
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
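For reference, the single fused call above replaces this three-step eager sequence, assembled from the removed lines into a standalone function (`rearrange` comes from einops):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def gdn_gating_ref(A_log: torch.Tensor, a: torch.Tensor,
                       b: torch.Tensor, dt_bias: torch.Tensor):
        # the removed lines, collected into one reference function:
        # two separate elementwise ops plus a reshape per forward pass
        beta = b.sigmoid()
        g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)
        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
        return g, beta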
@@ -1289,12 +1286,13 @@ direct_register_custom_op(
 )
 
 
-# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
 def fused_gdn_gating_kernel(
     g,
+    beta_output,
     A_log,
     a,
+    b,
     dt_bias,
     seq_len,
     NUM_HEADS: tl.constexpr,
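The loads in the next hunk index through `off` and `head_off`, which the (unchanged) kernel prologue derives from the launch grid. A plain-Python sketch of that index math, assuming the grid `(batch, seq_len, cdiv(num_heads, 8))` used by the wrapper below; `flat_offsets` is a hypothetical helper, not part of the kernel:

    def flat_offsets(i_b: int, i_s: int, i_d: int,
                     seq_len: int, num_heads: int, blk_heads: int = 8) -> list[int]:
        # hypothetical mirror of the kernel's indexing: program (i_b, i_s, i_d)
        # covers one blk_heads-wide block of heads for one (batch, seq) position
        head_off = [i_d * blk_heads + lane for lane in range(blk_heads)]
        # the kernel's mask = head_off < NUM_HEADS drops tail lanes of the last block
        return [i_b * seq_len * num_heads + i_s * num_heads + h
                for h in head_off if h < num_heads]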
@@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
     mask = head_off < NUM_HEADS
     blk_A_log = tl.load(A_log + head_off, mask=mask)
     blk_a = tl.load(a + off, mask=mask)
+    blk_b = tl.load(b + off, mask=mask)
     blk_bias = tl.load(dt_bias + head_off, mask=mask)
     # If the model is loaded in fp16, without the .float() here, A might be -inf
     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
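Between this hunk and the next, the kernel turns `x` into `softplus_x` using the `beta` and `threshold` launch arguments. A sketch of the conventional thresholded softplus (the semantics of `F.softplus`), which avoids overflowing `exp` for large inputs:

    import math

    def softplus_ref(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
        # log1p(exp(beta*x)) / beta overflows exp() for large x; past the
        # threshold, softplus(x) equals x to within float32 rounding
        if beta * x <= threshold:
            return (1.0 / beta) * math.log1p(math.exp(beta * x))
        return x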
@@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
     )
     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    # compute beta_output = sigmoid(b)
+    blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
+    tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
 
 
 def fused_gdn_gating(
     A_log: torch.Tensor,
     a: torch.Tensor,
+    b: torch.Tensor,
     dt_bias: torch.Tensor,
     beta: float = 1.0,
     threshold: float = 20.0,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused computation of g and beta for Gated Delta Net.
+    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+    beta_output = b.sigmoid()
+    TODO maybe use torch.compile to replace this triton kernel
+    """
     batch, num_heads = a.shape
     seq_len = 1
     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
     fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        num_warps=1,
     )
-    return g
+    return g, beta_output
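A quick equivalence check of the fused path against the eager reference (a sketch; shapes and tolerances are illustrative, and both fused outputs are fp32 regardless of input dtype, hence the `.float()` on the reference side):

    import torch
    import torch.nn.functional as F
    # assumes fused_gdn_gating is imported from the module patched above

    num_tokens, num_heads = 128, 32
    A_log = torch.randn(num_heads, device="cuda")
    dt_bias = torch.randn(num_heads, device="cuda")
    a = torch.randn(num_tokens, num_heads, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(num_tokens, num_heads, device="cuda", dtype=torch.bfloat16)

    g, beta = fused_gdn_gating(A_log, a, b, dt_bias)  # each (1, num_tokens, num_heads)

    ref_g = -A_log.exp() * F.softplus(a.float() + dt_bias)
    ref_beta = b.float().sigmoid()
    torch.testing.assert_close(g[0], ref_g, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(beta[0], ref_beta, rtol=1e-4, atol=1e-4)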