[FEAT][ROCm]: Support AITER MLA on V1 Engine (#17523)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
This commit is contained in:
vllmellm 2025-05-09 10:42:05 +08:00 committed by GitHub
parent 376786fac1
commit 3c9396a64f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 269 additions and 14 deletions

View File

@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_BRANCH="5a77249"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base

View File

@ -102,7 +102,10 @@ def test_env(
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == name
if use_v1 and name != "TRITON_MLA":
assert backend.get_name() == f"{name}_VLLM_V1"
else:
assert backend.get_name() == name
else:
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,

View File

@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
# If attention backend is None
# If use_mla is true
@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")

View File

@ -4,6 +4,9 @@ from typing import Optional
import torch
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
max_block_per_batch: int,
@ -30,6 +33,28 @@ def aiter_mla_decode_fwd(
kv_last_page_lens: Optional[torch.Tensor] = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
kv_buffer.view(
-1, 1, 1, q.shape[-1]),
o,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)
def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(q,
@ -40,3 +65,24 @@ def aiter_mla_decode_fwd(
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)
def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Fake (meta) implementation for the custom-op registration.

    The real kernel writes its result into ``o`` in-place and returns
    nothing, so the fake impl — used only for tracing/compilation —
    needs no computation at all.
    """
    pass
# Register the AITER MLA decode forward as a torch custom op, but only on
# ROCm, where the `aiter` package is available. `mutates_args=["o"]` tells
# the dispatcher the output tensor is written in-place; the
# needs_fixed_stride_order tag keeps torch.compile from permuting strides
# of the kernel's inputs.
if current_platform.is_rocm():
    direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                              op_func=mla_decode_fwd_impl,
                              mutates_args=["o"],
                              fake_impl=mla_decode_fwd_fake,
                              tags=[torch.Tag.needs_fixed_stride_order])

View File

@ -1319,6 +1319,7 @@ class EngineArgs:
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):

View File

@ -145,7 +145,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=torch.bf16)
return torch.empty_like(a1, dtype=hidden_states_dtype)
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,

View File

@ -39,7 +39,8 @@ class _Backend(enum.Enum):
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1

View File

@ -168,10 +168,15 @@ class RocmPlatform(Platform):
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
elif selected_backend == _Backend.ROCM_AITER_MLA:
elif selected_backend == _Backend.ROCM_AITER_MLA \
or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
if block_size == 1:
logger.info("Using AITER MLA backend.")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
if use_v1:
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
logger.info("Using AITER MLA backend")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"

View File

@ -496,11 +496,12 @@ class MLACommonMetadataBuilder(Generic[M]):
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
if self.aot_schedule:
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

View File

@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Optional
import torch
import vllm.envs as envs
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
# yapf: enable
def is_aiter_mla_enabled() -> bool:
    """Report whether the ROCm AITER MLA path is switched on.

    Both environment flags must be set: the global AITER switch and the
    MLA-specific one.
    """
    if not envs.VLLM_ROCM_USE_AITER:
        return False
    return envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
    """V1-engine MLA backend that dispatches to ROCm AITER kernels.

    Only the name and the concrete impl/metadata classes differ from
    the common MLA backend; all heavy lifting stays in the base class.
    """

    @staticmethod
    def get_name() -> str:
        """Backend identifier used by the attention-backend selector."""
        return "ROCM_AITER_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AiterMLAMetadata"]:
        """Concrete metadata type carried through the forward pass."""
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
        """Builder that assembles AiterMLAMetadata per batch."""
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["AiterMLAImpl"]:
        """Attention implementation class for this backend."""
        return AiterMLAImpl
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode-phase metadata extended with the paged-KV index tensors
    consumed by the AITER MLA decode kernel."""
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
    """Common MLA metadata specialized to carry AiterMLADecodeMetadata
    for the decode phase; adds no fields of its own."""
    pass
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Builds :class:`AiterMLAMetadata`, adding the paged-KV index
    tensors (indptr / indices / last-page lengths) that the AITER MLA
    decode kernel consumes.
    """

    def __init__(self, runner):
        super().__init__(runner)
        # Use explicit raises rather than `assert`: these are user
        # configuration errors and must still fire under `python -O`,
        # where asserts are stripped.
        max_model_len = self.runner.model_config.max_model_len
        if max_model_len != 32768:
            raise ValueError(
                "AITER MLA requires max_model_len=32768, got "
                f"{max_model_len}")
        if self.runner.block_size != 1:
            raise ValueError(
                "AITER MLA only supports block size 1, got "
                f"{self.runner.block_size}")

    def _get_paged_kv_tensors(
            self, block_table: torch.Tensor,
            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Flatten the per-request block table into CSR-style tensors.

        Returns (paged_kv_indices, paged_kv_indptr,
        paged_kv_last_page_len), where indptr has shape
        [batch_size + 1], indices holds the used page ids of all
        requests concatenated, and last_page_len the number of valid
        entries in each request's final page.
        """
        page_size = self.runner.block_size
        # Number of pages each request actually uses (ceil division).
        block_table_bounds = (seq_lens + page_size - 1) // page_size

        # Keep only the in-use prefix of each row of the block table.
        mask = (torch.arange(block_table.size(1),
                             dtype=block_table.dtype,
                             device=block_table.device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table[mask]

        # CSR row pointers: cumulative page counts prefixed with 0.
        paged_kv_indptr = torch.cat([
            torch.zeros(1,
                        dtype=block_table_bounds.dtype,
                        device=block_table_bounds.device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        # A remainder of 0 means the last page is completely full.
        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size,
                                             paged_kv_last_page_len)
        return (
            paged_kv_indices,
            paged_kv_indptr,
            paged_kv_last_page_len,
        )

    def _build_decode(self, input_positions: torch.Tensor,
                      block_table: torch.Tensor,
                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
        """Assemble decode metadata including the paged-KV tensors."""
        (
            paged_kv_indices,
            paged_kv_indptr,
            paged_last_page_len,
        ) = self._get_paged_kv_tensors(block_table, seq_lens)

        attn_metadata = AiterMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_last_page_len)

        return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA attention implementation backed by ROCm AITER kernels.

    Prefill goes through aiter's flash_attn_varlen_func; decode goes
    through the rocm_aiter_mla_decode_fwd custom op.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        # Fail fast on features the AITER kernels do not implement.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # Imported lazily so this module can be loaded on non-ROCm
        # builds where the `aiter` package is absent.
        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        # aiter names the LSE flag `return_lse`, unlike the common
        # implementation's `return_softmax_lse` — translate it here.
        output = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

        return output

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        B = q_nope.shape[0]

        # Full query = no-positional part concatenated with the RoPE part.
        q = torch.cat([q_nope, q_pe], dim=-1)
        # Output buffer the decode kernel writes into in-place
        # (the custom op declares `o` in mutates_args).
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # Insert a singleton axis; presumably the kernel expects a
        # [num_pages, page_size, 1, head_dim]-style layout — TODO confirm
        # against the aiter mla_decode_fwd contract.
        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                             attn_metadata.decode.paged_kv_indptr,
                             attn_metadata.decode.paged_kv_indices,
                             attn_metadata.decode.paged_kv_last_page_len)

        # Project the latent (kv_lora_rank) output back up to value space.
        return self._v_up_proj(o)