[FEAT][ROCm]: Support AITER MLA (#15893)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
This commit is contained in:
vllmellm 2025-04-23 00:31:13 +08:00 committed by GitHub
parent f34410715f
commit 30bc3e0f66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 668 additions and 30 deletions

View File

@ -19,45 +19,152 @@ def clear_cache():
_cached_get_attn_backend.cache_clear() _cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize( # Define MLA and non-MLA backends separately
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}
DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER"],
"hip": ["ROCM_FLASH"],
"cpu": ["TORCH_SDPA"],
}
DEVICE_MLA_BLOCK_SIZES = {
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
"hip": [16, 1], # HIP requires special handling for block_size=1
"cpu": [16] # CPU uses fixed block size from test cases
}
def generate_params():
params = []
for use_mla in [True, False]:
for device in ["cuda", "hip", "cpu"]:
backends = DEVICE_MLA_BACKENDS[
device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
for name in backends:
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
16
]
for block_size in block_sizes:
params.append(
pytest.param(
device,
name,
use_mla,
block_size,
id=
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
))
return params
@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
@pytest.mark.parametrize("use_v1", [True, False]) @pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env( def test_env(
name: str,
use_v1: bool,
device: str, device: str,
name: str,
use_mla: bool,
block_size: int,
use_v1: bool,
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
): ):
"""Test that the attention selector can be set via environment variable. """Test attention backend selection with valid device-backend pairs."""
Note that we do not test FlashAttn because it is the default backend.
"""
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
CpuPlatform()): CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, backend = get_attn_backend(16, torch.float16, torch.float16,
16, False) block_size, False)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
RocmPlatform()): RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, if use_mla:
16, False) # Validate HIP MLA backend-block_size combinations
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" valid_combination = (
assert backend.get_name() == EXPECTED (name == "TRITON_MLA" and block_size != 1)
else: or (name == "ROCM_AITER_MLA" and block_size == 1))
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform", if valid_combination:
CudaPlatform()): backend = get_attn_backend(16,
backend = get_attn_backend(16, torch.float16, torch.float16,
torch.float16, 16, False) torch.float16,
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name block_size,
assert backend.get_name() == EXPECTED False,
use_mla=use_mla)
assert backend.get_name() == name
else:
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == expected
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()
if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = ("TRITON_MLA_VLLM_V1"
if use_v1 else "TRITON_MLA")
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):

View File

@ -28,7 +28,34 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
assert (backend.get_name() == "ROCM_FLASH" assert (backend.get_name() == "ROCM_FLASH"
or backend.get_name() == "TRITON_ATTN_VLLM_V1") or backend.get_name() == "TRITON_ATTN_VLLM_V1")
# mla test for deepseek related # MLA test for deepseek related
# change the attention backend to triton MLA
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True) False, True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# If attention backend is None
# If use_mla is true
# The selected backend is triton MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True)
assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
# If attention backend is None
# If use_mla is true
# If VLLM_ROCM_USE_AITER is enabled
# The selected backend is ROCM_AITER_MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"

View File

@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
self._ops_advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions)
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops.advance_step_flashattn(num_seqs=num_seqs, ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries, num_queries=num_queries,
block_size=block_size, block_size=block_size,
input_tokens=model_input.input_tokens, input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions, input_positions=input_positions,
seq_lens=self.seq_lens_tensor, seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping,
block_tables=self.block_tables) block_tables=self.block_tables)
@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
BLOCK_TABLE_EXTENDER: list[list[int]] = []
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder self.input_builder = input_builder
@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables( block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables) num_seqs, self.block_tables)
else: else:

View File

@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, Union
import torch
import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
MLACommonState)
from vllm.attention.backends.utils import (compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
get_aiter_mla_metadata)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
@staticmethod
def get_impl_cls() -> Type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_metadata_cls() -> Type["AiterMLAMetadata"]:
return AiterMLAMetadata
@staticmethod
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["AiterMLAState"]:
return AiterMLAState
@dataclass
class AiterMLAMetadata(MLACommonMetadata):
# The following 4 tensors are for current version of AITER MLA
block_table_bound: Optional[torch.Tensor] = None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self):
prefill_metadata = super().prefill_metadata
self._cached_prefill_metadata = prefill_metadata
if prefill_metadata is not None:
prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
prefill_metadata.paged_kv_indices = self.paged_kv_indices
prefill_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
prefill_metadata.block_table_bound = self.block_table_bound
# update the cache
self._cached_prefill_metadata = self.__class__(
**prefill_metadata.__dict__)
return self._cached_prefill_metadata
@property
def decode_metadata(self):
decode_metadata = super().decode_metadata
self._cached_decode_metadata = decode_metadata
if decode_metadata is not None:
decode_metadata.paged_kv_indptr = self.paged_kv_indptr
decode_metadata.paged_kv_indices = self.paged_kv_indices
decode_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
decode_metadata.block_table_bound = self.block_table_bound
# update the cache
self._cached_decode_metadata = self.__class__(
**decode_metadata.__dict__)
return self._cached_decode_metadata
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
ops.advance_step_flashinfer(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables,
paged_kv_indices=self.paged_kv_indices,
paged_kv_indptr=self.paged_kv_indptr,
paged_kv_last_page_lens=self.paged_kv_last_page_lens,
block_table_bound=self.block_table_bound)
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
super().__init__(input_builder)
assert self.runner.model_config.max_model_len == 32768,\
"AITER MLA requires max model len to be set to 32768"
assert self.block_size == 1, "AITER MLA requires only block size 1."
def prepare(self):
super().prepare()
self.paged_kv_indices: list[int] = []
self.paged_kv_indptr: list[int] = [0]
self.paged_kv_last_page_lens: list[int] = []
self.total_blocks = 0
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
if is_profile_run:
return
# Update paged_kv_* tensors only for non-profile run
block_table = block_tables[seq_id]
self._update_paged_kv_tensors(block_table, seq_len)
def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_lens.append(last_page_len)
def build(self, seq_lens: list[int], query_lens: list[int],
cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
batch_size)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
if use_captured_graph:
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
# For current version of AITER MLA
if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device=device,
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
device=device,
dtype=torch.int)
paged_kv_last_page_lens_tensor = torch.tensor(
self.paged_kv_last_page_lens, device=device, dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device=device,
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_lens_tensor = None
block_table_bound_tensor = None
metadata.paged_kv_indptr = paged_kv_indptr_tensor
metadata.paged_kv_indices = paged_kv_indices_tensor
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
metadata.block_table_bound = block_table_bound_tensor
return metadata
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
@contextmanager
def graph_capture(self, max_batch_size: int):
kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
max_batch_size=max_batch_size,
block_size=self.runner.block_size,
max_block_per_batch=self.runner.get_max_block_per_batch(),
device=self.runner.device)
self._paged_kv_indices_tensor = kv_indices
self._paged_kv_indptr_tensor = kv_indptr
self._paged_kv_last_page_lens_tensor = last_page_lens
with super().graph_capture(max_batch_size):
yield
del self._paged_kv_indices_tensor
del self._paged_kv_indptr_tensor
del self._paged_kv_last_page_lens_tensor
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
metadata = super().graph_capture_get_metadata_for_batch(
batch_size, is_encoder_decoder_model)
paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
paged_kv_indices = self._paged_kv_indices_tensor
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
batch_size]
metadata.paged_kv_indptr = paged_kv_indptr
metadata.paged_kv_indices = paged_kv_indices
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
return metadata
def get_graph_input_buffers(self,
attn_metadata: AiterMLAMetadata,
is_encoder_decoder_model: bool = False):
input_buffers = super().get_graph_input_buffers(
attn_metadata, is_encoder_decoder_model)
input_buffers[
'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr
input_buffers[
"paged_kv_indices"] = attn_metadata.\
decode_metadata.paged_kv_indices
input_buffers[
"paged_kv_last_page_lens"] = attn_metadata.\
decode_metadata.paged_kv_last_page_lens
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata: AiterMLAMetadata,
is_encoder_decoder_model: bool = False):
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
is_encoder_decoder_model)
num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[
0]
input_buffers["paged_kv_indptr"].copy_(
attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True)
input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True)
input_buffers["paged_kv_last_page_lens"].copy_(
attn_metadata.decode_metadata.paged_kv_last_page_lens,
non_blocking=True)
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
softmax_scale: float, return_softmax_lse: bool,
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
return output
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens)
return self._v_up_proj_and_o_proj(o)

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
max_block_per_batch: int,
device: torch.device) -> tuple[torch.Tensor, ...]:
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
dtype=torch.int32,
device=device)
paged_kv_indptr = torch.zeros(max_batch_size + 1,
dtype=torch.int32,
device=device)
paged_kv_last_page_lens = torch.full((max_batch_size, ),
block_size,
dtype=torch.int32)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
logit_cap: float = 0.0,
):
from aiter.mla import mla_decode_fwd
mla_decode_fwd(q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)

View File

@ -1248,7 +1248,7 @@ class ModelConfig:
or getattr(self.hf_config, "is_matryoshka", False)) or getattr(self.hf_config, "is_matryoshka", False))
BlockSize = Literal[8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"] PrefixCachingHashAlgo = Literal["builtin", "sha256"]

View File

@ -79,6 +79,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True
@ -558,6 +559,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")), ("true", "1")),
# Whether to use aiter mla ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),
# use rocm skinny gemms # use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": "VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in

View File

@ -39,6 +39,7 @@ class _Backend(enum.Enum):
TRITON_ATTN_VLLM_V1 = enum.auto() TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto() XFORMERS = enum.auto()
ROCM_FLASH = enum.auto() ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto()
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA = enum.auto() # Supported by V1

View File

@ -141,8 +141,36 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_mla: if use_mla:
logger.info("Using Triton MLA backend.") from vllm.attention.backends.rocm_aiter_mla import (
return "vllm.attention.backends.triton_mla.TritonMLABackend" is_aiter_mla_enabled)
if selected_backend is None:
selected_backend = (_Backend.ROCM_AITER_MLA if
is_aiter_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA)
if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
elif selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1:
logger.info("Using AITER MLA backend.")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)")
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
@ -317,4 +345,4 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_cu_count(cls, device_id: int = 0) -> int: def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties( return torch.cuda.get_device_properties(
device_id).multi_processor_count device_id).multi_processor_count