mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:54:33 +08:00
[Kernel] Fix fused_gdn_gating (#28343)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
a65a934ebe
commit
c4768dcf47
@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel(
|
||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||
# compute beta_output = sigmoid(b)
|
||||
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
|
||||
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
|
||||
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
|
||||
tl.store(
|
||||
beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
|
||||
)
|
||||
|
||||
|
||||
def fused_gdn_gating(
|
||||
@ -1389,7 +1391,7 @@ def fused_gdn_gating(
|
||||
seq_len = 1
|
||||
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
|
||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
|
||||
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
|
||||
fused_gdn_gating_kernel[grid](
|
||||
g,
|
||||
beta_output,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user