[ROCm][Aiter] Add triton fp8 bmm kernel for mla (#23264)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com>
Divakar Verma 2025-08-28 13:18:08 -05:00 committed by GitHub
parent f32a5bc505
commit 04d1dd7f4a
2 changed files with 104 additions and 12 deletions


@@ -99,6 +99,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_USE_AITER_RMSNORM: bool = True
    VLLM_ROCM_USE_AITER_MLA: bool = True
    VLLM_ROCM_USE_AITER_MHA: bool = True
    VLLM_ROCM_USE_AITER_FP8BMM: bool = True
    VLLM_ROCM_USE_SKINNY_GEMM: bool = True
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
@@ -774,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
             ("true", "1")),

    # Whether to use the aiter Triton fp8 BMM kernel.
    # Enabled by default.
    "VLLM_ROCM_USE_AITER_FP8BMM":
    lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
             ("true", "1")),

    # use rocm skinny gemms
    "VLLM_ROCM_USE_SKINNY_GEMM":
    lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
@@ -1272,6 +1279,7 @@ def compute_hash() -> str:
        "VLLM_ROCM_USE_AITER_RMSNORM",
        "VLLM_ROCM_USE_AITER_MLA",
        "VLLM_ROCM_USE_AITER_MHA",
        "VLLM_ROCM_USE_AITER_FP8BMM",
        "VLLM_ROCM_USE_SKINNY_GEMM",
        "VLLM_ROCM_FP8_PADDING",
        "VLLM_ROCM_MOE_PADDING",


@@ -193,6 +193,7 @@ from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union

import torch
from tqdm import tqdm

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -203,6 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
@@ -234,6 +236,28 @@ try:
except ImportError:
    flashinfer_available = False


def is_rocm_aiter_fp8bmm_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_FP8BMM \
        and envs.VLLM_ROCM_USE_AITER


if is_rocm_aiter_fp8bmm_enabled():
    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
        as aiter_triton_fp8_bmm)

    def dynamic_per_batched_tensor_quant(
            x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
        DTYPE_MAX = torch.finfo(dtype).max
        min_val, max_val = x.aminmax()
        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
        scale = DTYPE_MAX / amax
        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


logger = init_logger(__name__)

CUDNN_WORKSPACE_SIZE = 12800
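A quick torch-only round trip illustrates what dynamic_per_batched_tensor_quant (defined above) returns: a saturated fp8 copy of the whole batched tensor plus a single scalar dequantization scale. Shapes below are illustrative; fp8 dtypes need a reasonably recent PyTorch.

    import torch

    W = torch.randn(16, 512, 128, dtype=torch.bfloat16)  # e.g. a (N, L, P) batched weight
    W_q, W_scale = dynamic_per_batched_tensor_quant(W, dtype=torch.float8_e4m3fn)

    assert W_q.dtype == torch.float8_e4m3fn and W_q.shape == W.shape
    assert W_scale.numel() == 1  # one scale for the entire batched tensor

    # Dequantize and inspect the rounding error introduced by fp8.
    W_deq = W_q.to(torch.float32) * W_scale
    print((W_deq - W.float()).abs().max())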
@@ -945,10 +969,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
    def _v_up_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + Transpose (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V)
            x = aiter_triton_fp8_bmm(x,
                                     self.W_V,
                                     self.W_V_scale,
                                     group_size=128,
                                     transpose_bm=True)
            # Convert from (B, N, V) to (B, N * V)
            x = x.reshape(-1, self.num_heads * self.v_head_dim)
        else:
            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            x = torch.bmm(x, self.W_UV)
            # Convert from (N, B, V) to (B, N * V)
            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        return x

    def process_weights_after_loading(self, act_dtype: torch.dtype):
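In plain PyTorch terms, the fused call above replaces a bf16 torch.bmm plus transpose with one quantized kernel that already emits the (B, N, V) layout. A torch-only sketch of the reference math and the resulting shape (sizes are illustrative; aiter_triton_fp8_bmm itself only runs on ROCm with aiter installed):

    import torch

    N, B, L, V = 16, 8, 512, 128   # heads, tokens, kv_lora_rank, v_head_dim
    x = torch.randn(N, B, L, dtype=torch.bfloat16)
    W_UV = torch.randn(N, L, V, dtype=torch.bfloat16)

    # Reference path: (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N * V)
    ref = torch.bmm(x, W_UV).transpose(0, 1).reshape(B, N * V)
    print(ref.shape)  # torch.Size([8, 2048])

    # The fp8 path computes the same product from quantized operands and,
    # with transpose_bm=True, returns (B, N, V) directly, so only the final
    # reshape to (B, N * V) remains.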
@@ -996,10 +1031,50 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        if is_rocm_aiter_fp8bmm_enabled():
            W_K = W_UK.transpose(0, 1)  # (L, N, P) -> (N, L, P), e.g. 16 x 512 x 128
            W_V = W_UV.permute(1, 2, 0)  # (L, N, V) -> (N, V, L), e.g. 16 x 128 x 512
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype())
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype())

            # The kernel operates on non-padded inputs, so pre-compile the
            # Triton kernel here to avoid runtime compilation for unseen
            # batch sizes. Pre-compile for batch sizes 1 to 1024 to cover
            # most use cases. On DS-R1, this step adds roughly 50s to the
            # model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
                                dtype=torch.bfloat16,
                                device=self.W_K.device)
                aiter_triton_fp8_bmm(x,
                                     self.W_K,
                                     self.W_K_scale,
                                     group_size=128,
                                     transpose_bm=True)

                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
                                dtype=torch.bfloat16,
                                device=self.W_V.device)
                aiter_triton_fp8_bmm(x,
                                     self.W_V,
                                     self.W_V_scale,
                                     group_size=128,
                                     transpose_bm=True)
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)

    def _compute_prefill_context(
        self,
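A torch-only sketch of the weight preparation above, with illustrative DeepSeek-like sizes (the per-head (out_features, in_features) weight layout is inferred from how W_K and W_V are consumed by the kernel, so treat it as an assumption):

    import torch

    # Per-TP-rank sizes: N heads, kv_lora_rank L, qk_nope_head_dim P, v_head_dim V.
    N, L, P, V = 16, 512, 128, 128
    kv_b_proj_weight = torch.randn(L, N, P + V, dtype=torch.bfloat16)
    W_UK, W_UV = kv_b_proj_weight.split([P, V], dim=-1)   # (L, N, P), (L, N, V)

    W_K = W_UK.transpose(0, 1)    # (N, L, P): absorbs q_nope, outputs kv_lora_rank
    W_V = W_UV.permute(1, 2, 0)   # (N, V, L): V up-projection, outputs v_head_dim

    # One-off fp8 quantization at load time, one scalar scale per tensor
    # (reusing the dynamic_per_batched_tensor_quant helper from this diff).
    W_K_q, W_K_scale = dynamic_per_batched_tensor_quant(W_K)
    W_V_q, W_V_scale = dynamic_per_batched_tensor_quant(W_V)

The warm-up loop then calls the kernel once per batch size from 1 to 1024 on empty bf16 activations, so the Triton specializations are compiled before serving starts rather than on the first request of each batch size.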
@@ -1203,10 +1278,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Convert from (B, N, P) to (N, B, P)
        decode_q_nope = decode_q_nope.transpose(0, 1)

        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + Transpose (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L)
            decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
                                                  self.W_K,
                                                  self.W_K_scale,
                                                  group_size=128,
                                                  transpose_bm=True)
        else:
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
            decode_ql_nope = decode_ql_nope.transpose(0, 1)

        if fp8_attention:
            ql_nope_shape = decode_ql_nope.shape
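For completeness, a torch-only sketch of how the fp8 decode path's weight relates to the original W_UK_T, and of the reference math it approximates (the on-the-fly activation quantization per 128-wide group is inferred from the kernel name and is an assumption; sizes are illustrative):

    import torch

    N, B, P, L = 16, 4, 128, 512
    W_UK = torch.randn(L, N, P, dtype=torch.bfloat16)
    decode_q_nope = torch.randn(N, B, P, dtype=torch.bfloat16)

    # Reference (unquantized) path from the else branch above:
    W_UK_T = W_UK.permute(1, 2, 0)                          # (N, P, L)
    ref = torch.bmm(decode_q_nope, W_UK_T).transpose(0, 1)  # (B, N, L)

    # The fp8 path stores the same weight as W_K = W_UK.transpose(0, 1),
    # i.e. (N, L, P) == W_UK_T.transpose(1, 2); the kernel quantizes the
    # activation on the fly and returns the (B, N, L) layout directly.
    W_K = W_UK.transpose(0, 1)
    assert torch.equal(W_K, W_UK_T.transpose(1, 2))
    print(ref.shape)  # torch.Size([4, 16, 512])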