[ROCm][FEAT] Enable Full Graph Mode in AITER MLA V1 Attn Backend (Decode Phase only) (#20254)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent 139508a418
commit a1aafc827a
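Summary of the change: this enables full CUDA graph capture for the AITER MLA V1 attention backend when a batch is pure decode. The metadata builder now advertises the capability via `full_cudagraph_supported: ClassVar[bool] = True`, preallocates persistent `paged_kv_indptr`, `paged_kv_indices`, `paged_kv_last_page_len`, and `qo_indptr` buffers sized for the runner's `max_num_reqs` (and the maximum possible number of KV pages) when `full_cuda_graph` is on, and rewrites `_build_decode` to copy each step's metadata into those buffers in place, padding the unused tails. The former `_get_paged_kv_tensors` helper is folded into `_build_decode`.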
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional
 
 import torch
 
@@ -63,6 +63,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 
 
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True  # decode only
 
     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
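Note: `full_cudagraph_supported` is the class-level flag V1 attention metadata builders use to tell the runner that the forward pass may be captured as a single CUDA graph; as the PR title says, the AITER MLA backend enables it for the decode phase only.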
@@ -70,56 +71,83 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."
 
-    def _get_paged_kv_tensors(
-            self, block_table: torch.Tensor,
-            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # Preparing persistent buffers
+        if self.runner.full_cuda_graph:
+            device = self.runner.device
+            max_num_reqs = self.runner.max_num_reqs
+            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device=device)
+            self.paged_kv_indices = torch.zeros(
+                block_table.get_device_tensor().numel(
+                ),  # max num pages possible
+                dtype=torch.int32,
+                device=device)
+            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device=device)
+
+            self.qo_indptr = torch.arange(0,
+                                          max_num_reqs + 1,
+                                          dtype=torch.int32,
+                                          device=device)
+
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
         device = self.runner.device
 
-        mask = (torch.arange(block_table.size(1),
-                             dtype=block_table.dtype,
+        mask = (torch.arange(block_table_tensor.size(1),
+                             dtype=block_table_tensor.dtype,
                              device=device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table[mask]
+        paged_kv_indices = block_table_tensor[mask]
+
+        paged_kv_last_page_len = seq_lens % page_size
+        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
+                                             page_size, paged_kv_last_page_len)
 
         paged_kv_indptr = torch.cat([
             torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])
 
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
-        qo_indptr = torch.arange(0,
-                                 self._num_decodes + 1,
-                                 step=1,
-                                 dtype=torch.int32,
-                                 device=device)
+        if self.runner.full_cuda_graph:
+            num_reqs = self._num_decodes
 
-        return (
-            paged_kv_indices,
-            paged_kv_indptr,
-            paged_kv_last_page_len,
-            qo_indptr,
-        )
+            num_actual_pages = paged_kv_indices.size(0)
+
+            self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
+                                                           non_blocking=True)
+            self.paged_kv_indices[num_actual_pages:].fill_(-1)
+            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
 
-    def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
+            self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
+                                                      non_blocking=True)
+            self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
+            paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
 
-        (
-            paged_kv_indices,
-            paged_kv_indptr,
-            paged_last_page_len,
-            qo_indptr,
-        ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
+            self.paged_kv_last_page_len[:num_reqs].copy_(
+                paged_kv_last_page_len, non_blocking=True)
+            self.paged_kv_last_page_len[num_reqs:].fill_(1)
+            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
+
+            qo_indptr = self.qo_indptr[:1 + num_reqs]
+
+        else:
+            qo_indptr = torch.arange(0,
+                                     self._num_decodes + 1,
+                                     step=1,
+                                     dtype=torch.int32,
+                                     device=device)
 
         attn_metadata = AiterMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
             paged_kv_indptr=paged_kv_indptr,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_last_page_len,
+            paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr)
 
         return attn_metadata
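The core trick here is that a captured CUDA graph replays its kernels against fixed tensor addresses, so per-step metadata must be written into preallocated storage rather than freshly allocated tensors. Below is a minimal, self-contained sketch of that pad-and-slice pattern, assuming illustrative names (`PersistentDecodeBuffers`, `max_num_reqs`, `max_num_pages`) that are not part of vLLM's API:

import torch


class PersistentDecodeBuffers:
    """Fixed-address buffers so a captured graph can be replayed while
    only the buffer contents change between steps (illustrative only)."""

    def __init__(self, max_num_reqs: int, max_num_pages: int,
                 device: torch.device):
        # Allocated once, at the maximum sizes the graph may ever see.
        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                           dtype=torch.int32, device=device)
        self.paged_kv_indices = torch.zeros(max_num_pages,
                                            dtype=torch.int32, device=device)
        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                  dtype=torch.int32,
                                                  device=device)

    def update(self, indptr: torch.Tensor, indices: torch.Tensor,
               last_page_len: torch.Tensor):
        num_reqs = last_page_len.size(0)
        num_pages = indices.size(0)
        # Copy this step's values into the head of each buffer ...
        self.paged_kv_indptr[:num_reqs + 1].copy_(indptr, non_blocking=True)
        self.paged_kv_indices[:num_pages].copy_(indices, non_blocking=True)
        self.paged_kv_last_page_len[:num_reqs].copy_(last_page_len,
                                                     non_blocking=True)
        # ... and pad the tails with harmless values so stale data from a
        # previous, larger batch cannot leak into the kernel.
        self.paged_kv_indptr[num_reqs + 1:].fill_(indptr[-1])
        self.paged_kv_indices[num_pages:].fill_(-1)
        self.paged_kv_last_page_len[num_reqs:].fill_(1)
        # Head slices into persistent storage: same data_ptr() every step.
        return (self.paged_kv_indptr[:num_reqs + 1],
                self.paged_kv_indices[:num_pages],
                self.paged_kv_last_page_len[:num_reqs])

During capture the attention kernel is recorded against these head slices; on replay, `update()` refreshes the contents in place, and a smaller batch reuses the same storage with its tail padded to harmless values (`-1` page indices, a repeated final indptr entry, last-page length 1), mirroring what `_build_decode` does in the diff above.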