mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[ROCm][Aiter] Add triton fp8 bmm kernel for mla (#23264)
Signed-off-by: Divakar Verma <divakar.verma@amd.com> Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com>
This commit is contained in:
parent
f32a5bc505
commit
04d1dd7f4a
@ -99,6 +99,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||
VLLM_ROCM_USE_AITER_MLA: bool = True
|
||||
VLLM_ROCM_USE_AITER_MHA: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
||||
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
@ -774,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# Whether to use aiter triton fp8 bmm kernel
|
||||
# By default is enabled.
|
||||
"VLLM_ROCM_USE_AITER_FP8BMM":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# use rocm skinny gemms
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
|
||||
@ -1272,6 +1279,7 @@ def compute_hash() -> str:
|
||||
"VLLM_ROCM_USE_AITER_RMSNORM",
|
||||
"VLLM_ROCM_USE_AITER_MLA",
|
||||
"VLLM_ROCM_USE_AITER_MHA",
|
||||
"VLLM_ROCM_USE_AITER_FP8BMM",
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM",
|
||||
"VLLM_ROCM_FP8_PADDING",
|
||||
"VLLM_ROCM_MOE_PADDING",
|
||||
|
||||
@ -193,6 +193,7 @@ from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
@ -203,6 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase,
|
||||
@ -234,6 +236,28 @@ try:
|
||||
except ImportError:
|
||||
flashinfer_available = False
|
||||
|
||||
|
||||
def is_rocm_aiter_fp8bmm_enabled() -> bool:
|
||||
return current_platform.is_rocm() \
|
||||
and envs.VLLM_ROCM_USE_AITER_FP8BMM \
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip
|
||||
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
|
||||
as aiter_triton_fp8_bmm)
|
||||
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
|
||||
DTYPE_MAX = torch.finfo(dtype).max
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
|
||||
scale = DTYPE_MAX / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
|
||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CUDNN_WORKSPACE_SIZE = 12800
|
||||
@ -945,10 +969,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = aiter_triton_fp8_bmm(x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
else:
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
@ -996,6 +1031,46 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||
W_K, dtype=current_platform.fp8_dtype())
|
||||
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
||||
W_V, dtype=current_platform.fp8_dtype())
|
||||
|
||||
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
||||
# triton kernel to avoid runtime compilation for unseen batch sizes
|
||||
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
||||
# On DS-R1, this step adds roughly 50s to the model loading time.
|
||||
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
||||
pre_compilation_list = list(range(1, max_batch_size + 1))
|
||||
if is_global_first_rank():
|
||||
pre_compilation_list = tqdm(
|
||||
pre_compilation_list,
|
||||
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
||||
total=max_batch_size,
|
||||
)
|
||||
|
||||
for m in pre_compilation_list:
|
||||
x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_K.device)
|
||||
aiter_triton_fp8_bmm(x,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
|
||||
x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_V.device)
|
||||
aiter_triton_fp8_bmm(x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
else:
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
@ -1203,6 +1278,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||
decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
else:
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user