[ROCm][Aiter] Add triton fp8 bmm kernel for mla (#23264)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com>

parent f32a5bc505
commit 04d1dd7f4a
@@ -99,6 +99,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -774,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
              ("true", "1")),
 
+    # Whether to use aiter triton fp8 bmm kernel
+    # By default is enabled.
+    "VLLM_ROCM_USE_AITER_FP8BMM":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
+             ("true", "1")),
+
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM":
     lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
@@ -1272,6 +1279,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_AITER_RMSNORM",
         "VLLM_ROCM_USE_AITER_MLA",
         "VLLM_ROCM_USE_AITER_MHA",
+        "VLLM_ROCM_USE_AITER_FP8BMM",
         "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",
         "VLLM_ROCM_MOE_PADDING",
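Note: the new flag follows the same pattern as the other ROCm/Aiter switches: read once from the environment, enabled by default, and folded into the compilation hash. A minimal standalone sketch of that parsing logic (the helper name below is illustrative, not part of vLLM):

import os

def rocm_aiter_fp8bmm_enabled_from_env() -> bool:
    # Mirrors the lambda registered above: enabled unless the variable is set
    # to something other than "true"/"1" (case-insensitive).
    return os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")

# Example: opt out of the Triton fp8 BMM path for this process.
os.environ["VLLM_ROCM_USE_AITER_FP8BMM"] = "0"
print(rocm_aiter_fp8bmm_enabled_from_env())  # False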
@@ -193,6 +193,7 @@ from dataclasses import dataclass, field
 from typing import ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
+from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -203,6 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -234,6 +236,28 @@ try:
 except ImportError:
     flashinfer_available = False
 
+
+def is_rocm_aiter_fp8bmm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_FP8BMM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+if is_rocm_aiter_fp8bmm_enabled():
+    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
+        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+        as aiter_triton_fp8_bmm)
+
+    def dynamic_per_batched_tensor_quant(
+            x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
+        DTYPE_MAX = torch.finfo(dtype).max
+        min_val, max_val = x.aminmax()
+        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
+        scale = DTYPE_MAX / amax
+        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
+        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 logger = init_logger(__name__)
 
 CUDNN_WORKSPACE_SIZE = 12800
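The helper above quantizes an entire weight tensor with a single scale (per batched tensor, not per channel). A standalone sketch of the same recipe, assuming PyTorch >= 2.1 for torch.float8_e4m3fn; illustrative only, it does not import vLLM or aiter:

import torch

def quant_per_batched_tensor(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # One scale for the whole tensor: scale = fp8_max / max(|x|).
    dtype_max = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = dtype_max / amax
    x_q = (x * scale).clamp(min=-dtype_max, max=dtype_max).to(dtype)
    return x_q.contiguous(), scale.float().reciprocal()

w = torch.randn(16, 512, 128, dtype=torch.bfloat16)    # batched weight stand-in
w_q, w_descale = quant_per_batched_tensor(w)
w_dq = w_q.to(torch.float32) * w_descale                # dequantize to check error
print((w_dq - w.float()).abs().max())                   # small fp8 rounding error

Using one scale per batched weight tensor keeps dequantization inside the kernel to a single multiply, at the cost of slightly coarser quantization than per-channel scaling.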
@@ -945,10 +969,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-        x = torch.bmm(x, self.W_UV)
-        # Convert from (N, B, V) to (B, N * V)
-        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if is_rocm_aiter_fp8bmm_enabled():
+            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
+            x = aiter_triton_fp8_bmm(x,
+                                     self.W_V,
+                                     self.W_V_scale,
+                                     group_size=128,
+                                     transpose_bm=True)
+            # Convert from (B, N, V) to (B, N * V)
+            x = x.reshape(-1, self.num_heads * self.v_head_dim)
+        else:
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            x = torch.bmm(x, self.W_UV)
+            # Convert from (N, B, V) to (B, N * V)
+            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        return x
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
 
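For orientation, a plain-PyTorch walkthrough of the fallback branch with illustrative sizes; the fused fp8 path is expected to produce the same (B, N * V) result, with the (0, 1) transpose folded into the kernel via transpose_bm=True and the weight pre-quantized to fp8:

import torch

B, N, L, V = 4, 16, 512, 128                # batch, heads, kv_lora_rank, v_head_dim
x = torch.randn(B, N * L)                   # per-token attention output, flattened
W_UV = torch.randn(N, L, V)                 # value up-projection weight

x = x.view(-1, N, L).transpose(0, 1)        # (B, N, L) -> (N, B, L)
x = torch.bmm(x, W_UV)                      # (N, B, L) x (N, L, V) -> (N, B, V)
out = x.transpose(0, 1).reshape(-1, N * V)  # (N, B, V) -> (B, N * V)
print(out.shape)                            # torch.Size([4, 2048])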
@@ -996,10 +1031,50 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1)
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0)
+        if is_rocm_aiter_fp8bmm_enabled():
+            W_K = W_UK.transpose(0, 1)  # 16 512 128
+            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
+                W_K, dtype=current_platform.fp8_dtype())
+            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
+                W_V, dtype=current_platform.fp8_dtype())
+
+            # The kernel operates on non-padded inputs. Hence, pre-compiling
+            # triton kernel to avoid runtime compilation for unseen batch sizes
+            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
+            # On DS-R1, this step adds roughly 50s to the model loading time.
+            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
+            pre_compilation_list = list(range(1, max_batch_size + 1))
+            if is_global_first_rank():
+                pre_compilation_list = tqdm(
+                    pre_compilation_list,
+                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
+                    total=max_batch_size,
+                )
+
+            for m in pre_compilation_list:
+                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
+                                dtype=torch.bfloat16,
+                                device=self.W_K.device)
+                aiter_triton_fp8_bmm(x,
+                                     self.W_K,
+                                     self.W_K_scale,
+                                     group_size=128,
+                                     transpose_bm=True)
+
+                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
+                                dtype=torch.bfloat16,
+                                device=self.W_V.device)
+                aiter_triton_fp8_bmm(x,
+                                     self.W_V,
+                                     self.W_V_scale,
+                                     group_size=128,
+                                     transpose_bm=True)
+        else:
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1)
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _compute_prefill_context(
         self,
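The warm-up loop above exists because the Triton kernel specializes per batch size, so the first call for an unseen size would otherwise trigger JIT compilation during serving. A minimal sketch of the same idea, where fake_bmm is only a stand-in for aiter_triton_fp8_bmm and the shapes are illustrative:

import torch
from tqdm import tqdm

def fake_bmm(x, w_q, w_descale, group_size=128, transpose_bm=True):
    # Stand-in with the same call shape as the real kernel; a real Triton
    # kernel would compile a specialized variant for each new batch size.
    return torch.bmm(x.float(), w_q.float().transpose(1, 2)) * w_descale

N, L, V = 16, 512, 128
w_q = torch.randn(N, V, L)              # quantized weight stand-in, (N, V, L)
w_descale = torch.tensor(1.0)

for m in tqdm(range(1, 65), desc="warm-up"):  # 1..64 here; the diff uses 1..1024
    x = torch.empty((N, m, L), dtype=torch.bfloat16)
    fake_bmm(x, w_q, w_descale)

Per the diff's own comment, pre-compiling batch sizes 1 to 1024 adds roughly 50 seconds to model load on DeepSeek-R1, traded for jitter-free decode afterwards.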
@@ -1203,10 +1278,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         # Convert from (B, N, P) to (N, B, P)
         decode_q_nope = decode_q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
+        if is_rocm_aiter_fp8bmm_enabled():
+            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
+            decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
+                                                  self.W_K,
+                                                  self.W_K_scale,
+                                                  group_size=128,
+                                                  transpose_bm=True)
+        else:
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape