From 9fb3ae4e6fec167232367b55ec85065545bf379d Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 8 Oct 2025 12:23:41 -0400
Subject: [PATCH] [Bug] Fix DeepGEMM Attention Test (#26423)

Signed-off-by: yewentao256
---
 pyproject.toml                                     | 1 +
 tests/kernels/attention/test_deepgemm_attention.py | 3 +--
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 704f28fa6536..471eed98f9ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -184,6 +184,7 @@ ba = "ba"
 
 [tool.typos.type.py.extend-words]
 ba = "ba"
+nd = "nd"
 
 [tool.typos.type.cpp]
 extend-glob = ["*.cu"]
diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py
index 4873afa649c9..f4b4fac84015 100644
--- a/tests/kernels/attention/test_deepgemm_attention.py
+++ b/tests/kernels/attention/test_deepgemm_attention.py
@@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
         torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
     )
     mask = mask_lo & mask_hi
-
-    score = torch.einsum("mhd,and->hmn", q, k)
+    score = torch.einsum("mhd,nd->hmn", q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
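
Note on the fix: the einsum subscript "nd" had evidently been rewritten to the
English word "and" (the companion pyproject.toml change whitelists "nd" under
[tool.typos.type.py.extend-words], which suggests the typos checker was the
culprit). The spec "mhd,and->hmn" names three subscripts for the second
operand, so it cannot apply to the 2-D key tensor; the patch restores
"mhd,nd->hmn". Below is a minimal sketch of what the corrected contraction
computes. The shapes (q as [M, H, D], k as [N, D]) and the sizes are
assumptions inferred from the subscripts and the MQA context, not values taken
from the test itself.

    import torch

    # Assumed shapes, inferred from the einsum spec "mhd,nd->hmn":
    # q: [M, H, D] query tokens x heads x head dim; k: [N, D] shared
    # key cache (a single KV head, hence MQA). Sizes are illustrative.
    M, H, D, N = 4, 2, 8, 6
    q = torch.randn(M, H, D)
    k = torch.randn(N, D)

    # Corrected contraction: score[h, m, n] = sum_d q[m, h, d] * k[n, d]
    score = torch.einsum("mhd,nd->hmn", q, k)

    # Equivalent explicit form: heads first, then a matmul against k^T.
    expected = q.permute(1, 0, 2) @ k.transpose(0, 1)  # [H, M, N]
    assert torch.allclose(score, expected)

    # The broken spec "mhd,and->hmn" declares three subscripts ("a",
    # "n", "d") for k, so torch.einsum raises a dimension-mismatch
    # error on the 2-D tensor rather than computing anything.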