Fix full CUDA-graph (decode-only) support for the ROCm AITER MLA backend:
declare `full_cudagraph_supported` on the metadata builder, allocate
persistent paged-KV indexing buffers when full CUDA graphs are enabled,
and fold `_get_paged_kv_tensors` into `_build_decode` so decode metadata
is written into those persistent buffers for graph replay.

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 8ad4e542b45b6..d5f9dfaea0655 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional
 
 import torch
 
@@ -63,6 +63,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 
 
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True  # decode only
 
     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -70,56 +71,83 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."
 
-    def _get_paged_kv_tensors(
-            self, block_table: torch.Tensor,
-            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # Preparing persistent buffers
+        if self.runner.full_cuda_graph:
+            device = self.runner.device
+            max_num_reqs = self.runner.max_num_reqs
+            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device=device)
+            self.paged_kv_indices = torch.zeros(
+                block_table.get_device_tensor().numel(
+                ),  # max num pages possible
+                dtype=torch.int32,
+                device=device)
+            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device=device)
+
+            self.qo_indptr = torch.arange(0,
+                                          max_num_reqs + 1,
+                                          dtype=torch.int32,
+                                          device=device)
+
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
         device = self.runner.device
 
-        mask = (torch.arange(block_table.size(1),
-                             dtype=block_table.dtype,
+        mask = (torch.arange(block_table_tensor.size(1),
+                             dtype=block_table_tensor.dtype,
                              device=device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table[mask]
+        paged_kv_indices = block_table_tensor[mask]
+
+        paged_kv_last_page_len = seq_lens % page_size
+        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
+                                             page_size, paged_kv_last_page_len)
 
         paged_kv_indptr = torch.cat([
             torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])
 
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
-        qo_indptr = torch.arange(0,
-                                 self._num_decodes + 1,
-                                 step=1,
-                                 dtype=torch.int32,
-                                 device=device)
+        if self.runner.full_cuda_graph:
+            num_reqs = self._num_decodes
 
-        return (
-            paged_kv_indices,
-            paged_kv_indptr,
-            paged_kv_last_page_len,
-            qo_indptr,
-        )
+            num_actual_pages = paged_kv_indices.size(0)
 
-    def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
+            self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
+                                                           non_blocking=True)
+            self.paged_kv_indices[num_actual_pages:].fill_(-1)
+            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
 
-        (
-            paged_kv_indices,
-            paged_kv_indptr,
-            paged_last_page_len,
-            qo_indptr,
-        ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
+            self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
+                                                      non_blocking=True)
+            self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
+            paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
+
+            self.paged_kv_last_page_len[:num_reqs].copy_(
+                paged_kv_last_page_len, non_blocking=True)
+            self.paged_kv_last_page_len[num_reqs:].fill_(1)
+            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
+
+            qo_indptr = self.qo_indptr[:1 + num_reqs]
+
+        else:
+            qo_indptr = torch.arange(0,
                                     self._num_decodes + 1,
                                     step=1,
                                     dtype=torch.int32,
                                     device=device)
 
         attn_metadata = AiterMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
             paged_kv_indptr=paged_kv_indptr,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_last_page_len,
+            paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr)
 
         return attn_metadata