From 04d1dd7f4a444a61ae4b01ea0271490082dbd605 Mon Sep 17 00:00:00 2001
From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Date: Thu, 28 Aug 2025 13:18:08 -0500
Subject: [PATCH] [ROCm][Aiter] Add triton fp8 bmm kernel for mla (#23264)

Signed-off-by: Divakar Verma
Co-authored-by: ShaoChunLee
---
 vllm/envs.py                             |   8 ++
 vllm/v1/attention/backends/mla/common.py | 108 ++++++++++++++++++++---
 2 files changed, 104 insertions(+), 12 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index a6a795dcfcda..1232bd7bf963 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -99,6 +99,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -774,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
              ("true", "1")),
 
+    # Whether to use the aiter triton fp8 bmm kernel.
+    # Enabled by default.
+    "VLLM_ROCM_USE_AITER_FP8BMM":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
+             ("true", "1")),
+
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM":
     lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
@@ -1272,6 +1279,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_AITER_RMSNORM",
         "VLLM_ROCM_USE_AITER_MLA",
         "VLLM_ROCM_USE_AITER_MHA",
+        "VLLM_ROCM_USE_AITER_FP8BMM",
        "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",
         "VLLM_ROCM_MOE_PADDING",
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index ce45b34f6435..9f93b50b075b 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -193,6 +193,7 @@ from dataclasses import dataclass, field
 from typing import ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
+from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -203,6 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -234,6 +236,28 @@ try:
 except ImportError:
     flashinfer_available = False
 
+
+def is_rocm_aiter_fp8bmm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_FP8BMM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+if is_rocm_aiter_fp8bmm_enabled():
+    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
+        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+        as aiter_triton_fp8_bmm)
+
+    def dynamic_per_batched_tensor_quant(
+            x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
+        DTYPE_MAX = torch.finfo(dtype).max
+        min_val, max_val = x.aminmax()
+        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
+        scale = DTYPE_MAX / amax
+        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
+        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 logger = init_logger(__name__)
 
 CUDNN_WORKSPACE_SIZE = 12800
@@ -945,10 +969,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-        x = torch.bmm(x, self.W_UV)
-        # Convert from (N, B, V) to (B, N * V)
-        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if is_rocm_aiter_fp8bmm_enabled():
+            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
+            x = aiter_triton_fp8_bmm(x,
+                                     self.W_V,
+                                     self.W_V_scale,
+                                     group_size=128,
+                                     transpose_bm=True)
+            # Convert from (B, N, V) to (B, N * V)
+            x = x.reshape(-1, self.num_heads * self.v_head_dim)
+        else:
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            x = torch.bmm(x, self.W_UV)
+            # Convert from (N, B, V) to (B, N * V)
+            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        return x
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -996,10 +1031,50 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1)
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0)
+        if is_rocm_aiter_fp8bmm_enabled():
+            W_K = W_UK.transpose(0, 1)  # (N, L, P), e.g. (16, 512, 128)
+            W_V = W_UV.permute(1, 2, 0)  # (N, V, L), e.g. (16, 128, 512)
+            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
+                W_K, dtype=current_platform.fp8_dtype())
+            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
+                W_V, dtype=current_platform.fp8_dtype())
+
+            # The kernel operates on non-padded inputs. Pre-compile the Triton
+            # kernel here to avoid runtime compilation for unseen batch sizes.
+            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
+            # On DS-R1, this step adds roughly 50s to the model loading time.
+            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
+            pre_compilation_list = list(range(1, max_batch_size + 1))
+            if is_global_first_rank():
+                pre_compilation_list = tqdm(
+                    pre_compilation_list,
+                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
+                    total=max_batch_size,
+                )
+
+            for m in pre_compilation_list:
+                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
+                                dtype=torch.bfloat16,
+                                device=self.W_K.device)
+                aiter_triton_fp8_bmm(x,
+                                     self.W_K,
+                                     self.W_K_scale,
+                                     group_size=128,
+                                     transpose_bm=True)
+
+                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
+                                dtype=torch.bfloat16,
+                                device=self.W_V.device)
+                aiter_triton_fp8_bmm(x,
+                                     self.W_V,
+                                     self.W_V_scale,
+                                     group_size=128,
+                                     transpose_bm=True)
+        else:
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1)
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _compute_prefill_context(
         self,
@@ -1203,10 +1278,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         # Convert from (B, N, P) to (N, B, P)
         decode_q_nope = decode_q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
+        if is_rocm_aiter_fp8bmm_enabled():
+            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
+            decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
+                                                  self.W_K,
+                                                  self.W_K_scale,
+                                                  group_size=128,
+                                                  transpose_bm=True)
+        else:
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape
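
Reviewer note (illustrative, not part of the patch): the helper dynamic_per_batched_tensor_quant added in mla/common.py stores each up-projection weight in fp8 together with a single per-tensor dequantization scale, so fp8_weight.to(float32) * scale approximately recovers the original weight. The standalone sketch below reproduces that helper outside vLLM and checks the round-trip error. The shapes and the __main__ harness are assumptions made for illustration; it only requires a PyTorch build with torch.float8_e4m3fn support.

# Illustrative round-trip check for the per-batched-tensor fp8 quantization.
import torch


def dynamic_per_batched_tensor_quant(x: torch.Tensor,
                                     dtype: torch.dtype = torch.float8_e4m3fn):
    # Same math as the helper added in mla/common.py: one scale per tensor,
    # chosen so the largest magnitude maps to the fp8 dtype's max value.
    DTYPE_MAX = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = DTYPE_MAX / amax
    x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


if __name__ == "__main__":
    # Shape mirrors a per-rank DeepSeek-R1-like W_K layout: (N, L, P).
    w_k = torch.randn(16, 512, 128, dtype=torch.bfloat16)
    w_k_fp8, w_k_scale = dynamic_per_batched_tensor_quant(w_k)
    # Dequantize and compare against the original bf16 weight.
    w_k_deq = w_k_fp8.to(torch.float32) * w_k_scale
    rel_err = (w_k_deq - w_k.float()).abs().max() / w_k.float().abs().max()
    print(f"fp8 dtype={w_k_fp8.dtype}, scale={w_k_scale.item():.3e}, "
          f"max relative error={rel_err.item():.3e}")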
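
Also illustrative rather than authoritative: as used above, aiter_triton_fp8_bmm(x, W, W_scale, group_size=128, transpose_bm=True) is expected to behave like a dequantized batched matmul against the transposed per-head weight followed by a swap of the head and batch dimensions, i.e. (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V) in _v_up_proj. The bf16 reference below sketches that contract under the assumption that the kernel's activation-side per-token-group fp8 quantization only affects accuracy, not shapes; the sizes are made up for the example.

# Illustrative bf16 reference for the fused fp8 BMM + transpose path.
import torch

N, B, L, V = 16, 8, 512, 128                      # heads, batch, lora rank, v dim
x = torch.randn(N, B, L, dtype=torch.bfloat16)    # per-head attention output
w_v = torch.randn(N, V, L, dtype=torch.bfloat16)  # W_V stored as (N, V, L)

# Per-batched-tensor fp8 quantization of the weight, mirroring what
# process_weights_after_loading does at weight-loading time.
dtype_max = torch.finfo(torch.float8_e4m3fn).max
amax = w_v.abs().amax().clamp(min=1e-10)
w_v_fp8 = (w_v * (dtype_max / amax)).clamp(min=-dtype_max,
                                           max=dtype_max).to(torch.float8_e4m3fn)
w_v_scale = (amax / dtype_max).float()

# Reference for aiter_triton_fp8_bmm(x, W_V, W_V_scale, transpose_bm=True):
# dequantize the weight, then (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V).
# The real kernel also fp8-quantizes x per token group of 128; skipped here.
w_v_deq = (w_v_fp8.to(torch.float32) * w_v_scale).to(torch.bfloat16)
out = torch.bmm(x, w_v_deq.transpose(1, 2)).transpose(0, 1)
print(out.shape)  # torch.Size([8, 16, 128]) == (B, N, V)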