[ROCm][FEAT] Enable Full Graph Mode in AITER MLA V1 Attn Backend (Decode Phase only) (#20254)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-07-03 00:25:46 +08:00 committed by GitHub
parent 139508a418
commit a1aafc827a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, ClassVar, Optional
import torch
@ -63,6 +63,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
@ -70,56 +71,83 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."
def _get_paged_kv_tensors(
self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
# Preparing persistent buffers
if self.runner.full_cuda_graph:
device = self.runner.device
max_num_reqs = self.runner.max_num_reqs
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.paged_kv_indices = torch.zeros(
block_table.get_device_tensor().numel(
), # max num pages possible
dtype=torch.int32,
device=device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=device)
self.qo_indptr = torch.arange(0,
max_num_reqs + 1,
dtype=torch.int32,
device=device)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device
mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype,
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]
paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
paged_kv_indptr = torch.cat([
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
if self.runner.full_cuda_graph:
num_reqs = self._num_decodes
return (
paged_kv_indices,
paged_kv_indptr,
paged_kv_last_page_len,
qo_indptr,
)
num_actual_pages = paged_kv_indices.size(0)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
non_blocking=True)
self.paged_kv_indices[num_actual_pages:].fill_(-1)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
(
paged_kv_indices,
paged_kv_indptr,
paged_last_page_len,
qo_indptr,
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
non_blocking=True)
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
self.paged_kv_last_page_len[:num_reqs].copy_(
paged_kv_last_page_len, non_blocking=True)
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
qo_indptr = self.qo_indptr[:1 + num_reqs]
else:
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_last_page_len,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr)
return attn_metadata