[Bugfix][Hardware][AMD] Fix last_page_len calculation in AITER MLA decode

The paged_kv_last_page_len was incorrectly set to the full sequence length
instead of 1. Since the AITER MLA kernel uses a block size of 1 (each page
contains exactly 1 token), the last_page_len should always be 1.
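To make the relationship concrete, here is a small standalone helper (illustrative only, not part of vLLM) showing how many tokens the final page holds for a given block size:

```python
def last_page_len(seq_len: int, block_size: int) -> int:
    """Tokens in the final (possibly partial) page of a paged KV cache."""
    if seq_len == 0:
        return 1  # convention: empty sequences still report a length of 1
    return ((seq_len - 1) % block_size) + 1

print(last_page_len(127, 16))  # 15 -> with block_size=16 the last page is partial
print(last_page_len(127, 1))   # 1  -> with block_size=1 every page holds 1 token
```

With block_size=1, a 127-token sequence occupies 127 pages of one token each, so the last page can never hold more than one token.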

Previous code:
  paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)

For a sequence of 127 tokens, this would set last_page_len=127, telling the
kernel the last page has 127 tokens when it only has 1.

This bug could cause incorrect attention scores or out-of-bounds memory access:
for any sequence longer than one token, the kernel was told the last page holds
more tokens than it actually does.

Fixed by setting last_page_len to 1 unconditionally, matching the kernel's
block_size=1 configuration.

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Author: c0de128  2025-12-24 07:44:41 -06:00
parent d201807339
commit 8756b5ed15


@@ -122,7 +122,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]
-        paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
+        # kernel block size is always 1, so each page has exactly 1 token.
+        # last_page_len should always be 1 regardless of sequence length.
+        paged_kv_last_page_len = torch.ones(
+            num_reqs, dtype=seq_lens_device.dtype, device=device
+        )
         paged_kv_indptr = torch.cat(
             [