mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 11:15:01 +08:00
[ROCm][MLA] Support block-size > 1 for AITER MLA backend (#27224)
Signed-off-by: ganyi <ygan@amd.com> Co-authored-by: wuhuikx <hattie.wu@amd.com>
This commit is contained in:
parent
80c9275348
commit
6cae1e5332
@ -104,13 +104,6 @@ def test_env(
|
|||||||
16, torch.float16, None, block_size, use_mla=use_mla
|
16, torch.float16, None, block_size, use_mla=use_mla
|
||||||
)
|
)
|
||||||
assert f"The selected backend, {name}" in str(exc_info.value)
|
assert f"The selected backend, {name}" in str(exc_info.value)
|
||||||
elif name == "ROCM_AITER_MLA" and block_size != 1:
|
|
||||||
# ROCM_AITER_MLA only supports block_size == 1
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
get_attn_backend(
|
|
||||||
16, torch.float16, None, block_size, use_mla=use_mla
|
|
||||||
)
|
|
||||||
assert f"The selected backend, {name}" in str(exc_info.value)
|
|
||||||
else:
|
else:
|
||||||
# Valid backend-block_size combination
|
# Valid backend-block_size combination
|
||||||
backend = get_attn_backend(
|
backend = get_attn_backend(
|
||||||
|
|||||||
@ -252,16 +252,9 @@ class RocmPlatform(Platform):
|
|||||||
f"does not support block size {block_size}."
|
f"does not support block size {block_size}."
|
||||||
)
|
)
|
||||||
if selected_backend == _Backend.ROCM_AITER_MLA:
|
if selected_backend == _Backend.ROCM_AITER_MLA:
|
||||||
if block_size == 1:
|
logger.info("Using AITER MLA backend.")
|
||||||
logger.info("Using AITER MLA backend.")
|
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||||
return (
|
|
||||||
"vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f" The selected backend, {selected_backend.name},"
|
|
||||||
f"does not support block size {block_size}."
|
|
||||||
"(currently only supports block size 1)"
|
|
||||||
)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f" The selected backend, {selected_backend.name},"
|
f" The selected backend, {selected_backend.name},"
|
||||||
f"is not MLA type while requested for MLA backend."
|
f"is not MLA type while requested for MLA backend."
|
||||||
|
|||||||
@ -78,9 +78,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
|
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
|
||||||
)
|
)
|
||||||
assert self.kv_cache_spec.block_size == 1, (
|
|
||||||
"AITER MLAonly supports block size 1."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
max_num_pages_per_req = cdiv(
|
max_num_pages_per_req = cdiv(
|
||||||
@ -94,6 +91,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
# so we can only use the persistent buffer if a cudagraph is actually
|
# so we can only use the persistent buffer if a cudagraph is actually
|
||||||
# being used.
|
# being used.
|
||||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||||
|
self.block_table_remapping = torch.zeros(
|
||||||
|
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
self.paged_kv_indptr = torch.zeros(
|
self.paged_kv_indptr = torch.zeros(
|
||||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
@ -119,13 +121,29 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||||
) -> AiterMLADecodeMetadata:
|
) -> AiterMLADecodeMetadata:
|
||||||
page_size = self.kv_cache_spec.block_size
|
page_size = self.kv_cache_spec.block_size
|
||||||
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
|
|
||||||
device = self.device
|
device = self.device
|
||||||
num_reqs = seq_lens_device.size(0)
|
num_reqs = seq_lens_device.size(0)
|
||||||
|
bs, _ = block_table_tensor.shape
|
||||||
|
block_table_tensor = (
|
||||||
|
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
|
||||||
|
)
|
||||||
|
block_table_tensor = (
|
||||||
|
block_table_tensor
|
||||||
|
+ torch.arange(
|
||||||
|
0,
|
||||||
|
page_size,
|
||||||
|
device=block_table_tensor.device,
|
||||||
|
dtype=block_table_tensor.dtype,
|
||||||
|
)[None, None, :]
|
||||||
|
)
|
||||||
|
block_table_tensor = block_table_tensor.view(bs, -1)
|
||||||
|
|
||||||
|
# after remapping, we assume the block size already equals to 1
|
||||||
|
|
||||||
|
max_blk_size_per_req = block_table_tensor.shape[-1]
|
||||||
mask = torch.arange(
|
mask = torch.arange(
|
||||||
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
||||||
).unsqueeze(0) < block_table_bounds.unsqueeze(1)
|
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
|
||||||
paged_kv_indices = block_table_tensor[mask]
|
paged_kv_indices = block_table_tensor[mask]
|
||||||
|
|
||||||
paged_kv_last_page_len = seq_lens_device % page_size
|
paged_kv_last_page_len = seq_lens_device % page_size
|
||||||
@ -135,13 +153,19 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
|
|
||||||
paged_kv_indptr = torch.cat(
|
paged_kv_indptr = torch.cat(
|
||||||
[
|
[
|
||||||
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
|
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
|
||||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32),
|
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||||
num_actual_pages = paged_kv_indices.size(0)
|
num_actual_pages = paged_kv_indices.size(0)
|
||||||
|
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
|
||||||
|
block_table_tensor, non_blocking=True
|
||||||
|
)
|
||||||
|
block_table_tensor = self.block_table_remapping[
|
||||||
|
:num_reqs, :max_blk_size_per_req
|
||||||
|
]
|
||||||
|
|
||||||
self.paged_kv_indices[:num_actual_pages].copy_(
|
self.paged_kv_indices[:num_actual_pages].copy_(
|
||||||
paged_kv_indices, non_blocking=True
|
paged_kv_indices, non_blocking=True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user