Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-03-29 15:13:40 +08:00)
[Bugfix][Hardware][AMD] Fix last_page_len calculation in AITER MLA decode
paged_kv_last_page_len was incorrectly set to the full sequence length instead of 1. Since the AITER MLA kernel uses a block size of 1 (each page contains exactly 1 token), last_page_len should always be 1.

Previous code:

    paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)

For a sequence of 127 tokens, this set last_page_len=127, telling the kernel the last page holds 127 tokens when it actually holds only 1. This bug could cause incorrect attention scores or out-of-bounds memory accesses for any sequence longer than one token.

Fixed by setting last_page_len to 1 unconditionally, matching the kernel's block_size=1 configuration.

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
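A minimal standalone sketch of the before-and-after semantics described above, in plain Python rather than the actual torch code (last_page_len_old and last_page_len_fixed are hypothetical helper names, not vLLM functions):

```python
def last_page_len_old(seq_lens):
    # Previous behavior: full sequence length, clamped to 1 only for
    # empty sequences (the torch.where(seq_lens == 0, 1, seq_lens) logic).
    return [1 if s == 0 else s for s in seq_lens]

def last_page_len_fixed(seq_lens):
    # Fixed behavior: with block_size=1 every page holds exactly 1 token,
    # so the last page always contains 1 token.
    return [1 for _ in seq_lens]

seq_lens = [127, 3, 1]
print(last_page_len_old(seq_lens))    # [127, 3, 1]
print(last_page_len_fixed(seq_lens))  # [1, 1, 1]
```

The 127-token case shows the failure mode: the old metadata claims 127 tokens on a page that can only hold one.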
This commit is contained in:
parent d201807339
commit 8756b5ed15
@@ -122,7 +122,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]
 
-        paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
+        # kernel block size is always 1, so each page has exactly 1 token.
+        # last_page_len should always be 1 regardless of sequence length.
+        paged_kv_last_page_len = torch.ones(
+            num_reqs, dtype=seq_lens_device.dtype, device=device
+        )
 
         paged_kv_indptr = torch.cat(
             [
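As a sanity check on the fix, the metadata can be related to the token count a paged-attention kernel would see: (num_pages - 1) * block_size + last_page_len. This is a hedged sketch under that assumption — tokens_seen is a hypothetical helper, not an AITER or vLLM API:

```python
def tokens_seen(seq_len, block_size, last_page_len):
    # Number of tokens implied by the paged-KV metadata: all full pages
    # plus whatever the last page is declared to hold.
    num_pages = max(1, -(-seq_len // block_size))  # ceil division
    return (num_pages - 1) * block_size + last_page_len

seq_len, block_size = 127, 1
print(tokens_seen(seq_len, block_size, 1))        # 127  (fixed: matches)
print(tokens_seen(seq_len, block_size, seq_len))  # 253  (old: over-counts)
```

With block_size=1 there is one page per token, so only last_page_len=1 makes the implied total equal the true sequence length; the old value over-counted by seq_len - 1 tokens.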