mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:14:57 +08:00
[Bug] Fix DeepGEMM Attention Test (#26423)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
76afe4edf8
commit
9fb3ae4e6f
@ -184,6 +184,7 @@ ba = "ba"
|
||||
|
||||
[tool.typos.type.py.extend-words]
|
||||
ba = "ba"
|
||||
nd = "nd"
|
||||
|
||||
[tool.typos.type.cpp]
|
||||
extend-glob = ["*.cu"]
|
||||
|
||||
@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
||||
)
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,and->hmn", q, k)
|
||||
score = torch.einsum("mhd,nd->hmn", q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user