Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 00:45:01 +08:00)
[Bug] Fix DeepGEMM Attention Test (#26423)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent: 76afe4edf8
commit: 9fb3ae4e6f
@ -184,6 +184,7 @@ ba = "ba"
|
|||||||
|
|
||||||
[tool.typos.type.py.extend-words]
|
[tool.typos.type.py.extend-words]
|
||||||
ba = "ba"
|
ba = "ba"
|
||||||
|
nd = "nd"
|
||||||
|
|
||||||
[tool.typos.type.cpp]
|
[tool.typos.type.cpp]
|
||||||
extend-glob = ["*.cu"]
|
extend-glob = ["*.cu"]
|
||||||
|
|||||||
@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
|
|||||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
||||||
)
|
)
|
||||||
mask = mask_lo & mask_hi
|
mask = mask_lo & mask_hi
|
||||||
|
score = torch.einsum("mhd,nd->hmn", q, k)
|
||||||
score = torch.einsum("mhd,and->hmn", q, k)
|
|
||||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||||
logits = logits.masked_fill(~mask, float("-inf"))
|
logits = logits.masked_fill(~mask, float("-inf"))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user