[ROCm][Aiter] Add triton fp8 bmm kernel for mla (#23264)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com>
Divakar Verma 2025-08-28 13:18:08 -05:00 committed by GitHub
parent f32a5bc505
commit 04d1dd7f4a
2 changed files with 104 additions and 12 deletions


@@ -99,6 +99,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -774,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
@@ -1272,6 +1279,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",


@@ -193,6 +193,7 @@ from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union
import torch
from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -203,6 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
@@ -234,6 +236,28 @@ try:
except ImportError:
    flashinfer_available = False


def is_rocm_aiter_fp8bmm_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_FP8BMM \
        and envs.VLLM_ROCM_USE_AITER


if is_rocm_aiter_fp8bmm_enabled():
    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
        as aiter_triton_fp8_bmm)

    def dynamic_per_batched_tensor_quant(
            x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
        DTYPE_MAX = torch.finfo(dtype).max
        min_val, max_val = x.aminmax()
        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
        scale = DTYPE_MAX / amax
        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


logger = init_logger(__name__)

CUDNN_WORKSPACE_SIZE = 12800
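Aside (not part of the diff): dynamic_per_batched_tensor_quant computes a single absolute-max scale for the whole batched weight tensor, saturates to the fp8 range, and returns the reciprocal scale for dequantization. A standalone sketch that restates the helper and checks the quantize/dequantize round trip on random data; the (16, 512, 128) shape mirrors the W_K comment later in the diff and is only illustrative:

import torch

def dynamic_per_batched_tensor_quant(
        x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # Same math as the helper above: one scale for the whole batched tensor.
    DTYPE_MAX = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = DTYPE_MAX / amax
    x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

w = torch.randn(16, 512, 128)
w_fp8, w_scale = dynamic_per_batched_tensor_quant(w)
w_dq = w_fp8.to(torch.float32) * w_scale   # dequantize with the reciprocal scale
rel_err = (w_dq - w).abs().max() / w.abs().max()
print(rel_err)   # small; bounded by fp8 e4m3 rounding error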
@@ -945,10 +969,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):

    def _v_up_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + transpose: (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V)
            x = aiter_triton_fp8_bmm(x,
                                     self.W_V,
                                     self.W_V_scale,
                                     group_size=128,
                                     transpose_bm=True)
            # Convert from (B, N, V) to (B, N * V)
            x = x.reshape(-1, self.num_heads * self.v_head_dim)
        else:
            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            x = torch.bmm(x, self.W_UV)
            # Convert from (N, B, V) to (B, N * V)
            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        return x
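Aside (not part of the diff): both branches of _v_up_proj produce the same (B, N * V) layout; the fp8 path simply folds the final (N, B, V) -> (B, N, V) transpose into the kernel via transpose_bm=True. A shape-only sketch of the bf16 fallback, with illustrative sizes N=16 heads, B=4 tokens, L=512 (kv_lora_rank), V=128 (v_head_dim):

import torch

N, B, L, V = 16, 4, 512, 128
x = torch.randn(B, N, L)                       # attention output in latent space
W_UV = torch.randn(N, L, V)                    # per-head V up-projection weights

x = x.view(-1, N, L).transpose(0, 1)           # (B, N, L) -> (N, B, L)
out = torch.bmm(x, W_UV)                       # (N, B, L) x (N, L, V) -> (N, B, V)
out = out.transpose(0, 1).reshape(-1, N * V)   # (N, B, V) -> (B, N * V)
assert out.shape == (B, N * V)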
    def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -996,6 +1031,46 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        if is_rocm_aiter_fp8bmm_enabled():
            W_K = W_UK.transpose(0, 1)  # (N, L, P), e.g. (16, 512, 128)
            W_V = W_UV.permute(1, 2, 0)  # (N, V, L), e.g. (16, 128, 512)
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype())
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype())

            # The kernel operates on non-padded inputs, so pre-compile the
            # Triton kernel here to avoid runtime compilation for unseen
            # batch sizes. Pre-compile for batch sizes 1 to 1024 to cover
            # most use cases. On DS-R1, this step adds roughly 50s to the
            # model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
                                dtype=torch.bfloat16,
                                device=self.W_K.device)
                aiter_triton_fp8_bmm(x,
                                     self.W_K,
                                     self.W_K_scale,
                                     group_size=128,
                                     transpose_bm=True)

                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
                                dtype=torch.bfloat16,
                                device=self.W_V.device)
                aiter_triton_fp8_bmm(x,
                                     self.W_V,
                                     self.W_V_scale,
                                     group_size=128,
                                     transpose_bm=True)
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
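Aside (not part of the diff): a sanity-check sketch for the quantized weights prepared above. It assumes a ROCm GPU with the aiter package installed, reuses the import path and call signature from this file, and assumes, based on the shape comments in _v_up_proj and in the decode path below, that the kernel computes a per-batch x @ W^T with a per-tensor weight scale and, with transpose_bm=True, swaps the first two output dimensions. The kernel also quantizes the activation per 128-element group internally, so only approximate agreement with a dequantized reference is expected:

import torch
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
    as aiter_triton_fp8_bmm)

fp8 = torch.float8_e4m3fn
N, B, P, L = 16, 8, 128, 512                   # heads, tokens, qk_nope_head_dim, kv_lora_rank
x = torch.randn(N, B, P, dtype=torch.bfloat16, device="cuda")
W_K = torch.randn(N, L, P, dtype=torch.bfloat16, device="cuda")

# Per-batched-tensor weight quantization, same scheme as dynamic_per_batched_tensor_quant.
amax = W_K.abs().amax().clamp(min=1e-10).float()
W_K_scale = amax / torch.finfo(fp8).max
W_K_fp8 = (W_K.float() / W_K_scale).clamp(-torch.finfo(fp8).max,
                                          torch.finfo(fp8).max).to(fp8)

out = aiter_triton_fp8_bmm(x, W_K_fp8, W_K_scale, group_size=128,
                           transpose_bm=True)  # expected shape (B, N, L)

# Reference on the dequantized weight.
ref = torch.bmm(x.float(), (W_K_fp8.float() * W_K_scale).transpose(1, 2))
ref = ref.transpose(0, 1)                      # (N, B, L) -> (B, N, L)
print((out.float() - ref).abs().max())         # small, dominated by activation quantization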
@@ -1203,6 +1278,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Convert from (B, N, P) to (N, B, P)
        decode_q_nope = decode_q_nope.transpose(0, 1)

        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + transpose: (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L)
            decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
                                                  self.W_K,
                                                  self.W_K_scale,
                                                  group_size=128,
                                                  transpose_bm=True)
        else:
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
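Aside (not part of the diff): the decode path mirrors _v_up_proj on the query side, absorbing the K up-projection so attention runs in the latent space; the fp8 path returns the (B, N, L) layout directly via transpose_bm=True. A shape-only sketch of the bf16 fallback, with illustrative sizes as before:

import torch

N, B, P, L = 16, 4, 128, 512
decode_q_nope = torch.randn(B, N, P).transpose(0, 1)   # (B, N, P) -> (N, B, P)
W_UK_T = torch.randn(N, P, L)                          # per-head absorbed K up-projection

decode_ql_nope = torch.bmm(decode_q_nope, W_UK_T)      # (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)        # (N, B, L) -> (B, N, L)
assert decode_ql_nope.shape == (B, N, L)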