mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 07:17:04 +08:00
[Kernel] Fix fused_gdn_gating (#28343)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
a65a934ebe
commit
c4768dcf47
@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel(
|
|||||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||||
# compute beta_output = sigmoid(b)
|
# compute beta_output = sigmoid(b)
|
||||||
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
|
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
|
||||||
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
|
tl.store(
|
||||||
|
beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fused_gdn_gating(
|
def fused_gdn_gating(
|
||||||
@ -1389,7 +1391,7 @@ def fused_gdn_gating(
|
|||||||
seq_len = 1
|
seq_len = 1
|
||||||
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
|
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
|
||||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||||
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
|
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
|
||||||
fused_gdn_gating_kernel[grid](
|
fused_gdn_gating_kernel[grid](
|
||||||
g,
|
g,
|
||||||
beta_output,
|
beta_output,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user