[Bug] Fix DeepGEMM Attention Test (#26423)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye · 2025-10-08 12:23:41 -04:00 · committed by GitHub
parent 76afe4edf8
commit 9fb3ae4e6f
2 changed files with 2 additions and 2 deletions

pyproject.toml

@@ -184,6 +184,7 @@ ba = "ba"
 [tool.typos.type.py.extend-words]
 ba = "ba"
+nd = "nd"

 [tool.typos.type.cpp]
 extend-glob = ["*.cu"]
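
Context for the hunk above: typos, the spell checker the repo runs via pre-commit, treats the bare token "nd" as a likely misspelling, which appears to be how the einsum subscript string in the test below got rewritten to "and" in the first place. Adding nd = "nd" to the Python extend-words table whitelists the token so the corrected subscripts survive future lint runs; the [tool.typos.type.cpp] section with extend-glob = ["*.cu"] is unchanged context that makes the C++ type rules also apply to CUDA sources.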

tests/kernels/attention/test_deepgemm_attention.py

@@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
         torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
     )
     mask = mask_lo & mask_hi
-    score = torch.einsum("mhd,and->hmn", q, k)
+    score = torch.einsum("mhd,nd->hmn", q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
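
Why the one-character change matters: in _ref_fp8_mqa_logits the query is laid out as (seq_len, num_heads, head_dim) and the key as (seq_len_kv, head_dim), so "mhd,nd->hmn" contracts head_dim and yields per-head scores of shape (num_heads, seq_len, seq_len_kv). That layout is what the weights line above expects, since weights (evidently of shape (seq_len, num_heads)) becomes (num_heads, seq_len, 1) after the unsqueeze/transpose and broadcasts cleanly before the sum over heads. The broken string "mhd,and->hmn" declares a spurious a axis on k, so einsum rejects the 2-D key outright. A minimal standalone sketch of the shape algebra; sizes and tensor values are illustrative assumptions, not taken from the test:

    import torch

    # Illustrative sizes (assumptions, not from the test):
    #   q: (m, h, d) = (seq_len, num_heads, head_dim)
    #   k: (n, d)    = (seq_len_kv, head_dim)
    m, h, d, n = 4, 2, 8, 6
    q = torch.randn(m, h, d)
    k = torch.randn(n, d)

    # Fixed subscripts: contract over d, keep (h, m, n).
    score = torch.einsum("mhd,nd->hmn", q, k)
    assert score.shape == (h, m, n)

    # Cross-check against an explicit batched matmul:
    # (h, m, d) @ (d, n) -> (h, m, n)
    expected = q.permute(1, 0, 2) @ k.transpose(0, 1)
    assert torch.allclose(score, expected, atol=1e-6)

    # The pre-fix string "mhd,and->hmn" declares a 3-D k of shape (a, n, d),
    # so it cannot even run against a 2-D key:
    try:
        torch.einsum("mhd,and->hmn", q, k)
    except RuntimeError as err:
        print("rejected as expected:", err)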