From 8756b5ed1568cf3676d7be5d39d022c11a36fbad Mon Sep 17 00:00:00 2001
From: c0de128
Date: Wed, 24 Dec 2025 07:44:41 -0600
Subject: [PATCH] [Bugfix][Hardware][AMD] Fix last_page_len calculation in
 AITER MLA decode

The paged_kv_last_page_len was incorrectly set to the full sequence
length instead of 1. Since the AITER MLA kernel uses a block size of 1
(each page contains exactly 1 token), the last_page_len should always
be 1.

Previous code:
    paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)

For a sequence of 127 tokens, this would set last_page_len=127, telling
the kernel the last page has 127 tokens when it only has 1. This bug
could cause incorrect attention scores or out-of-bounds memory accesses
for any sequence longer than one token, since the kernel would be told
the final page holds more tokens than it actually does.

Fixed by setting last_page_len to 1 unconditionally, matching the
kernel's block_size=1 configuration.

Signed-off-by: c0de128
---
 vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index e8921f8a1c403..751726c15880d 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -122,7 +122,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
 
         paged_kv_indices = block_table_tensor[mask]
-        paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
+        # kernel block size is always 1, so each page has exactly 1 token.
+        # last_page_len should always be 1 regardless of sequence length.
+        paged_kv_last_page_len = torch.ones(
+            num_reqs, dtype=seq_lens_device.dtype, device=device
+        )
 
         paged_kv_indptr = torch.cat(
             [