diff --git a/pyproject.toml b/pyproject.toml
index 704f28fa6536..471eed98f9ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -184,6 +184,7 @@ ba = "ba"
 [tool.typos.type.py.extend-words]
 ba = "ba"
+nd = "nd"
 
 [tool.typos.type.cpp]
 extend-glob = ["*.cu"]
 
diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py
index 4873afa649c9..f4b4fac84015 100644
--- a/tests/kernels/attention/test_deepgemm_attention.py
+++ b/tests/kernels/attention/test_deepgemm_attention.py
@@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
         torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
     )
     mask = mask_lo & mask_hi
-
-    score = torch.einsum("mhd,and->hmn", q, k)
+    score = torch.einsum("mhd,nd->hmn", q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
 