diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index d28bc065852db..188becb6ad6f0 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -119,14 +119,12 @@ class AttentionBackend(ABC):
             return True
 
         for supported_size in cls.supported_kernel_block_sizes:
-            is_multiple_of = (
-                isinstance(supported_size, MultipleOf)
-                and block_size % supported_size.base == 0
-            )
-            is_int_equal = (
-                isinstance(supported_size, int) and block_size == supported_size
-            )
-            if is_multiple_of or is_int_equal:
+            if isinstance(supported_size, MultipleOf):
+                supported_size = supported_size.base
+            # With the hybrid_blocks feature, the framework-level block size
+            # only needs to be a multiple of the kernel's requirement, even
+            # when the kernel declares a fixed block size.
+            if block_size % supported_size == 0:
                 return True
         return False
 
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index e1864526f02cc..6ccc1a341d56c 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -7,9 +7,8 @@ from typing import ClassVar
 import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.attention.backends.abstract import AttentionLayer
+from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
 from vllm.config import VllmConfig
-from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
     MLACommonDecodeMetadata,
@@ -22,6 +21,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 
 class AiterMLABackend(MLACommonBackend):
+    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_AITER_MLA"
@@ -71,9 +72,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         )
         self.compilation_config = vllm_config.compilation_config
 
-        max_num_pages_per_req = cdiv(
-            vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
-        )
+        # The kernel block size is always 1, so one page is needed per token.
+        max_num_pages_per_req = vllm_config.model_config.max_model_len
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
 
@@ -82,11 +82,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         # so we can only use the persistent buffer if a cudagraph is actually
         # being used.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-            self.block_table_remapping = torch.zeros(
-                [max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
-                dtype=torch.int32,
-                device=device,
-            )
             self.paged_kv_indptr = torch.zeros(
                 max_num_reqs + 1, dtype=torch.int32, device=device
             )
@@ -111,36 +106,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         num_decode_tokens: int,
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> AiterMLADecodeMetadata:
-        page_size = self.kv_cache_spec.block_size
+        # The kernel block size is always 1, even if the kv block size is not.
         device = self.device
         num_reqs = seq_lens_device.size(0)
 
-        bs, _ = block_table_tensor.shape
-        block_table_tensor = (
-            block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
-        )
-        block_table_tensor = (
-            block_table_tensor
-            + torch.arange(
-                0,
-                page_size,
-                device=block_table_tensor.device,
-                dtype=block_table_tensor.dtype,
-            )[None, None, :]
-        )
-        block_table_tensor = block_table_tensor.view(bs, -1)
-        # after remapping, we assume the block size already equals to 1
-
-        max_blk_size_per_req = block_table_tensor.shape[-1]
         mask = torch.arange(
             block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
         ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]
 
-        paged_kv_last_page_len = seq_lens_device % page_size
-        paged_kv_last_page_len = torch.where(
-            paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
-        )
+        paged_kv_last_page_len = torch.ones_like(seq_lens_device)
 
         paged_kv_indptr = torch.cat(
             [
@@ -151,12 +126,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
 
-            self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
-                block_table_tensor, non_blocking=True
-            )
-            block_table_tensor = self.block_table_remapping[
-                :num_reqs, :max_blk_size_per_req
-            ]
             self.paged_kv_indices[:num_actual_pages].copy_(
                 paged_kv_indices, non_blocking=True
             )
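Note: a minimal standalone sketch of the relaxed check in abstract.py, for reference. The `MultipleOf` stub and the free-function form are illustrative, not vLLM's actual module layout; the point is that a fixed int entry now reduces to a divisibility constraint, which is what lets a larger framework block size sit on top of a fixed kernel block size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for vLLM's MultipleOf marker: any multiple of `base` is valid.
    base: int


def supports_block_size(
    supported_kernel_block_sizes: list[int | MultipleOf], block_size: int
) -> bool:
    for supported_size in supported_kernel_block_sizes:
        if isinstance(supported_size, MultipleOf):
            supported_size = supported_size.base
        # Both int and MultipleOf entries reduce to a divisibility check.
        if block_size % supported_size == 0:
            return True
    return False


assert supports_block_size([MultipleOf(16)], 32)  # multiple-of constraint
assert supports_block_size([16], 32)  # fixed size: multiples now accepted too
assert not supports_block_size([16], 24)  # 24 is not a multiple of 16
assert supports_block_size([1], 128)  # AiterMLABackend: 1 divides everything
```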
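For the decode path, a small runnable sketch of what the metadata looks like once the kernel page size is 1. It assumes the framework has already re-mapped `block_table_tensor` to per-token granularity before `_build_decode` runs (my reading of the hybrid-blocks feature); the slot ids and sequence lengths below are made up.

```python
import torch

# Three decode requests with 3, 1, and 2 tokens of KV history.
seq_lens = torch.tensor([3, 1, 2])
# One slot id per token (kernel page size 1); zeros are padding.
block_table = torch.tensor(
    [
        [10, 11, 12, 0],
        [20, 0, 0, 0],
        [30, 31, 0, 0],
    ],
    dtype=torch.int32,
)

# Keep only the valid entries, flattened request by request.
mask = torch.arange(
    block_table.size(1), dtype=block_table.dtype
).unsqueeze(0) < seq_lens.unsqueeze(1)
paged_kv_indices = block_table[mask]

paged_kv_indptr = torch.cat(
    [torch.zeros(1, dtype=seq_lens.dtype), seq_lens.cumsum(dim=0)]
)
# One token per page, so every request's last page is exactly full.
paged_kv_last_page_len = torch.ones_like(seq_lens)

assert paged_kv_indices.tolist() == [10, 11, 12, 20, 30, 31]
assert paged_kv_indptr.tolist() == [0, 3, 4, 6]
assert paged_kv_last_page_len.tolist() == [1, 1, 1]
```

This also shows why the remapping buffer and the `seq_lens % page_size` arithmetic could be dropped: with one token per page, the valid prefix of each block-table row already is the per-token index list, and the last page is always exactly full.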