diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index e8921f8a1c403..751726c15880d 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -122,7 +122,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ).unsqueeze(0) < seq_lens_device.unsqueeze(1) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device) + # kernel block size is always 1, so each page has exactly 1 token. + # last_page_len should always be 1 regardless of sequence length. + paged_kv_last_page_len = torch.ones( + num_reqs, dtype=seq_lens_device.dtype, device=device + ) paged_kv_indptr = torch.cat( [