mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:54:56 +08:00
[FEAT][ROCm]: Support AITER MLA (#15893)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: qli88 <qiang.li2@amd.com>
This commit is contained in:
parent
f34410715f
commit
30bc3e0f66
@ -19,45 +19,152 @@ def clear_cache():
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
|
||||
# Define MLA and non-MLA backends separately
|
||||
DEVICE_MLA_BACKENDS = {
|
||||
"cuda": ["TRITON_MLA", "FLASHMLA"],
|
||||
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
|
||||
"cpu": [],
|
||||
}
|
||||
|
||||
DEVICE_REGULAR_ATTN_BACKENDS = {
|
||||
"cuda": ["XFORMERS", "FLASHINFER"],
|
||||
"hip": ["ROCM_FLASH"],
|
||||
"cpu": ["TORCH_SDPA"],
|
||||
}
|
||||
|
||||
DEVICE_MLA_BLOCK_SIZES = {
|
||||
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
|
||||
"hip": [16, 1], # HIP requires special handling for block_size=1
|
||||
"cpu": [16] # CPU uses fixed block size from test cases
|
||||
}
|
||||
|
||||
|
||||
def generate_params():
|
||||
params = []
|
||||
for use_mla in [True, False]:
|
||||
for device in ["cuda", "hip", "cpu"]:
|
||||
backends = DEVICE_MLA_BACKENDS[
|
||||
device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
|
||||
for name in backends:
|
||||
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
|
||||
16
|
||||
]
|
||||
for block_size in block_sizes:
|
||||
params.append(
|
||||
pytest.param(
|
||||
device,
|
||||
name,
|
||||
use_mla,
|
||||
block_size,
|
||||
id=
|
||||
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
|
||||
))
|
||||
return params
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device, name, use_mla, block_size",
|
||||
generate_params())
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
||||
def test_env(
|
||||
name: str,
|
||||
use_v1: bool,
|
||||
device: str,
|
||||
name: str,
|
||||
use_mla: bool,
|
||||
block_size: int,
|
||||
use_v1: bool,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test that the attention selector can be set via environment variable.
|
||||
Note that we do not test FlashAttn because it is the default backend.
|
||||
"""
|
||||
|
||||
"""Test attention backend selection with valid device-backend pairs."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, name)
|
||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
block_size, False)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
RocmPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||
assert backend.get_name() == EXPECTED
|
||||
else:
|
||||
if name in ["XFORMERS", "FLASHINFER"]:
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16,
|
||||
torch.float16, 16, False)
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == EXPECTED
|
||||
if use_mla:
|
||||
# Validate HIP MLA backend-block_size combinations
|
||||
valid_combination = (
|
||||
(name == "TRITON_MLA" and block_size != 1)
|
||||
or (name == "ROCM_AITER_MLA" and block_size == 1))
|
||||
|
||||
if valid_combination:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
torch.float16,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert backend.get_name() == name
|
||||
else:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(16,
|
||||
torch.float16,
|
||||
torch.float16,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
torch.float16,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
if use_mla:
|
||||
if name == "FLASHMLA" and block_size == 64:
|
||||
from vllm.attention.backends.flashmla import (
|
||||
is_flashmla_supported)
|
||||
|
||||
# only on cuda platforms with specific capability.
|
||||
is_supported, _ = is_flashmla_supported()
|
||||
|
||||
if not is_supported:
|
||||
# if platform is not supported then skip this case.
|
||||
pytest.skip()
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
torch.float16,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = f"{name}_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
torch.float16,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = ("TRITON_MLA_VLLM_V1"
|
||||
if use_v1 else "TRITON_MLA")
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
torch.float16,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == expected
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
@ -28,7 +28,34 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
assert (backend.get_name() == "ROCM_FLASH"
|
||||
or backend.get_name() == "TRITON_ATTN_VLLM_V1")
|
||||
|
||||
# mla test for deepseek related
|
||||
# MLA test for deepseek related
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
|
||||
False, True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# If VLLM_ROCM_USE_AITER is enabled
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
|
||||
False, True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
self.seq_lens[i] += 1
|
||||
self.max_decode_seq_len = max(self.seq_lens)
|
||||
|
||||
self._ops_advance_step(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions)
|
||||
|
||||
def _ops_advance_step(self, num_seqs: int, num_queries: int,
|
||||
block_size: int, input_tokens: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
input_positions: torch.Tensor) -> None:
|
||||
# here we use advance_step_flashinfo to update the paged_kv_* tensors
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
input_tokens=input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
input_positions=input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
BLOCK_TABLE_EXTENDER: list[list[int]] = []
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.input_builder = input_builder
|
||||
@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
|
||||
cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
|
||||
412
vllm/attention/backends/rocm_aiter_mla.py
Normal file
412
vllm/attention/backends/rocm_aiter_mla.py
Normal file
@ -0,0 +1,412 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
MLACommonState)
|
||||
from vllm.attention.backends.utils import (compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
|
||||
get_aiter_mla_metadata)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER \
|
||||
and envs.VLLM_ROCM_USE_AITER_MLA
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AiterMLAMetadata"]:
|
||||
return AiterMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["AiterMLAState"]:
|
||||
return AiterMLAState
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLAMetadata(MLACommonMetadata):
|
||||
# The following 4 tensors are for current version of AITER MLA
|
||||
block_table_bound: Optional[torch.Tensor] = None
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: Optional[torch.Tensor] = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: Optional[torch.Tensor] = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_lens: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self):
|
||||
prefill_metadata = super().prefill_metadata
|
||||
self._cached_prefill_metadata = prefill_metadata
|
||||
|
||||
if prefill_metadata is not None:
|
||||
prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
|
||||
prefill_metadata.paged_kv_indices = self.paged_kv_indices
|
||||
prefill_metadata\
|
||||
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
|
||||
prefill_metadata.block_table_bound = self.block_table_bound
|
||||
|
||||
# update the cache
|
||||
self._cached_prefill_metadata = self.__class__(
|
||||
**prefill_metadata.__dict__)
|
||||
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self):
|
||||
decode_metadata = super().decode_metadata
|
||||
|
||||
self._cached_decode_metadata = decode_metadata
|
||||
|
||||
if decode_metadata is not None:
|
||||
decode_metadata.paged_kv_indptr = self.paged_kv_indptr
|
||||
decode_metadata.paged_kv_indices = self.paged_kv_indices
|
||||
decode_metadata\
|
||||
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
|
||||
decode_metadata.block_table_bound = self.block_table_bound
|
||||
|
||||
# update the cache
|
||||
self._cached_decode_metadata = self.__class__(
|
||||
**decode_metadata.__dict__)
|
||||
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def _ops_advance_step(self, num_seqs: int, num_queries: int,
|
||||
block_size: int, input_tokens: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
input_positions: torch.Tensor) -> None:
|
||||
|
||||
ops.advance_step_flashinfer(
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables,
|
||||
paged_kv_indices=self.paged_kv_indices,
|
||||
paged_kv_indptr=self.paged_kv_indptr,
|
||||
paged_kv_last_page_lens=self.paged_kv_last_page_lens,
|
||||
block_table_bound=self.block_table_bound)
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
super().__init__(input_builder)
|
||||
assert self.runner.model_config.max_model_len == 32768,\
|
||||
"AITER MLA requires max model len to be set to 32768"
|
||||
assert self.block_size == 1, "AITER MLA requires only block size 1."
|
||||
|
||||
def prepare(self):
|
||||
super().prepare()
|
||||
self.paged_kv_indices: list[int] = []
|
||||
self.paged_kv_indptr: list[int] = [0]
|
||||
self.paged_kv_last_page_lens: list[int] = []
|
||||
self.total_blocks = 0
|
||||
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block, input_positions) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
inter_data.input_positions):
|
||||
self.input_positions.extend(input_positions)
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
if is_profile_run:
|
||||
return
|
||||
|
||||
# Update paged_kv_* tensors only for non-profile run
|
||||
block_table = block_tables[seq_id]
|
||||
self._update_paged_kv_tensors(block_table, seq_len)
|
||||
|
||||
def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
|
||||
# Get the number of valid blocks based on sequence length.
|
||||
# If seq_len = 16, block_size = 16,
|
||||
# block_table_bound is 1 with 1 valid block.
|
||||
# If seq_len = 15, block_size = 16,
|
||||
# block_table_bound is 0 + 1 with 1 valid block.
|
||||
self.total_blocks += len(block_table)
|
||||
block_table_bound = seq_len // self.block_size + 1 \
|
||||
if seq_len % self.block_size != 0 \
|
||||
else seq_len // self.block_size
|
||||
self.paged_kv_indices.extend(block_table[:block_table_bound])
|
||||
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
|
||||
block_table_bound)
|
||||
|
||||
last_page_len = seq_len % self.block_size
|
||||
if last_page_len == 0:
|
||||
last_page_len = self.block_size
|
||||
self.paged_kv_last_page_lens.append(last_page_len)
|
||||
|
||||
def build(self, seq_lens: list[int], query_lens: list[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
|
||||
metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
|
||||
batch_size)
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
if use_captured_graph:
|
||||
last_paged_kv_indptr = self.paged_kv_indptr[-1]
|
||||
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
|
||||
cuda_graph_pad_size)
|
||||
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
|
||||
|
||||
# For current version of AITER MLA
|
||||
if len(self.paged_kv_indptr) > 0:
|
||||
# extend to the maximum number of blocks as returned by the
|
||||
# scheduler
|
||||
self.paged_kv_indices.extend(
|
||||
[0] * (self.total_blocks - len(self.paged_kv_indices)))
|
||||
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
paged_kv_last_page_lens_tensor = torch.tensor(
|
||||
self.paged_kv_last_page_lens, device=device, dtype=torch.int)
|
||||
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
|
||||
1,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
else:
|
||||
paged_kv_indices_tensor = None
|
||||
paged_kv_indptr_tensor = None
|
||||
paged_kv_last_page_lens_tensor = None
|
||||
block_table_bound_tensor = None
|
||||
|
||||
metadata.paged_kv_indptr = paged_kv_indptr_tensor
|
||||
metadata.paged_kv_indices = paged_kv_indices_tensor
|
||||
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
|
||||
metadata.block_table_bound = block_table_bound_tensor
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
|
||||
max_batch_size=max_batch_size,
|
||||
block_size=self.runner.block_size,
|
||||
max_block_per_batch=self.runner.get_max_block_per_batch(),
|
||||
device=self.runner.device)
|
||||
self._paged_kv_indices_tensor = kv_indices
|
||||
self._paged_kv_indptr_tensor = kv_indptr
|
||||
self._paged_kv_last_page_lens_tensor = last_page_lens
|
||||
|
||||
with super().graph_capture(max_batch_size):
|
||||
yield
|
||||
|
||||
del self._paged_kv_indices_tensor
|
||||
del self._paged_kv_indptr_tensor
|
||||
del self._paged_kv_last_page_lens_tensor
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
|
||||
|
||||
metadata = super().graph_capture_get_metadata_for_batch(
|
||||
batch_size, is_encoder_decoder_model)
|
||||
|
||||
paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
|
||||
paged_kv_indices = self._paged_kv_indices_tensor
|
||||
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
|
||||
batch_size]
|
||||
|
||||
metadata.paged_kv_indptr = paged_kv_indptr
|
||||
metadata.paged_kv_indices = paged_kv_indices
|
||||
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
|
||||
|
||||
return metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = super().get_graph_input_buffers(
|
||||
attn_metadata, is_encoder_decoder_model)
|
||||
input_buffers[
|
||||
'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr
|
||||
input_buffers[
|
||||
"paged_kv_indices"] = attn_metadata.\
|
||||
decode_metadata.paged_kv_indices
|
||||
input_buffers[
|
||||
"paged_kv_last_page_lens"] = attn_metadata.\
|
||||
decode_metadata.paged_kv_last_page_lens
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
|
||||
is_encoder_decoder_model)
|
||||
|
||||
num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[
|
||||
0]
|
||||
input_buffers["paged_kv_indptr"].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True)
|
||||
input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True)
|
||||
input_buffers["paged_kv_last_page_lens"].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_last_page_lens,
|
||||
non_blocking=True)
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
softmax_scale: float, return_softmax_lse: bool,
|
||||
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||
attn_metadata.paged_kv_indptr,
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_lens)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
42
vllm/attention/ops/rocm_aiter_mla.py
Normal file
42
vllm/attention/ops/rocm_aiter_mla.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
|
||||
max_block_per_batch: int,
|
||||
device: torch.device) -> tuple[torch.Tensor, ...]:
|
||||
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
paged_kv_indptr = torch.zeros(max_batch_size + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
paged_kv_last_page_lens = torch.full((max_batch_size, ),
|
||||
block_size,
|
||||
dtype=torch.int32)
|
||||
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
|
||||
|
||||
|
||||
def aiter_mla_decode_fwd(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
sm_scale: float,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
logit_cap: float = 0.0,
|
||||
):
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
mla_decode_fwd(q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap)
|
||||
@ -1248,7 +1248,7 @@ class ModelConfig:
|
||||
or getattr(self.hf_config, "is_matryoshka", False))
|
||||
|
||||
|
||||
BlockSize = Literal[8, 16, 32, 64, 128]
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
|
||||
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
|
||||
|
||||
|
||||
@ -79,6 +79,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
||||
VLLM_ROCM_USE_AITER_MOE: bool = True
|
||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||
VLLM_ROCM_USE_AITER_MLA: bool = True
|
||||
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
@ -558,6 +559,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# Whether to use aiter mla ops.
|
||||
# By default is enabled.
|
||||
"VLLM_ROCM_USE_AITER_MLA":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
|
||||
("true", "1")),
|
||||
# use rocm skinny gemms
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
|
||||
|
||||
@ -39,6 +39,7 @@ class _Backend(enum.Enum):
|
||||
TRITON_ATTN_VLLM_V1 = enum.auto()
|
||||
XFORMERS = enum.auto()
|
||||
ROCM_FLASH = enum.auto()
|
||||
ROCM_AITER_MLA = enum.auto()
|
||||
TORCH_SDPA = enum.auto()
|
||||
FLASHINFER = enum.auto()
|
||||
TRITON_MLA = enum.auto() # Supported by V1
|
||||
|
||||
@ -141,8 +141,36 @@ class RocmPlatform(Platform):
|
||||
kv_cache_dtype, block_size, use_v1,
|
||||
use_mla) -> str:
|
||||
if use_mla:
|
||||
logger.info("Using Triton MLA backend.")
|
||||
return "vllm.attention.backends.triton_mla.TritonMLABackend"
|
||||
from vllm.attention.backends.rocm_aiter_mla import (
|
||||
is_aiter_mla_enabled)
|
||||
|
||||
if selected_backend is None:
|
||||
selected_backend = (_Backend.ROCM_AITER_MLA if
|
||||
is_aiter_mla_enabled() or block_size == 1
|
||||
else _Backend.TRITON_MLA)
|
||||
|
||||
if selected_backend == _Backend.TRITON_MLA:
|
||||
if block_size != 1:
|
||||
logger.info("Using Triton MLA backend.")
|
||||
return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}.")
|
||||
elif selected_backend == _Backend.ROCM_AITER_MLA:
|
||||
if block_size == 1:
|
||||
logger.info("Using AITER MLA backend.")
|
||||
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}."
|
||||
"(currently only supports block size 1)")
|
||||
else:
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"is not MLA type while requested for MLA backend.")
|
||||
|
||||
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||
== _Backend.FLASH_ATTN else selected_backend)
|
||||
if envs.VLLM_USE_V1:
|
||||
@ -317,4 +345,4 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def get_cu_count(cls, device_id: int = 0) -> int:
|
||||
return torch.cuda.get_device_properties(
|
||||
device_id).multi_processor_count
|
||||
device_id).multi_processor_count
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user