[FEAT] [ROCm]: Add AITER Block-Scaled GEMM Feature (#14968)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>

Parent: 0189a65a2e
Commit: 40de1ef455
tests/model_executor/test_enabled_custom_ops.py

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+import torch
 
 from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.custom_op import CustomOp
@@ -16,6 +17,8 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
 from vllm.model_executor.layers.layernorm import (
     RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
     rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 
 
@@ -98,6 +101,34 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()
 
 
+@pytest.mark.skipif(
+    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
+    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
+@pytest.mark.parametrize("use_cutlass", [True, False])
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
+def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
+                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
+                                  monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
+                       use_rocm_aiter_gemm_w8a8_blockscale)
+
+    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
+        int(use_rocm_aiter_gemm_w8a8_blockscale)))
+    block_scale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
+    if use_cutlass:
+        assert block_scale_func == cutlass_scaled_mm
+    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
+            use_rocm_aiter_gemm_w8a8_blockscale):
+        assert block_scale_func == (
+            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
+    else:
+        assert block_scale_func == w8a8_block_fp8_matmul
+
+
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
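The parametrization above covers every combination of the two environment flags, and the asserts reduce to a small decision table. A minimal standalone sketch of that rule (the helper below is hypothetical, written only to mirror the asserts; the real dispatcher is dispatch_w8a8_blockscale_func in fp8_utils.py):

def expected_kernel(use_cutlass: bool, use_aiter: str,
                    use_aiter_linear: str) -> str:
    # Hypothetical mirror of the asserts in test_w8a8_blockscale_dispatch.
    if use_cutlass:
        return "cutlass_scaled_mm"                # CUTLASS always wins if requested
    if use_aiter == "1" and use_aiter_linear == "1":
        return "rocm_aiter_gemm_w8a8_blockscale"  # AITER: ROCm + FP8_FNUZ only
    return "w8a8_block_fp8_matmul"                # Triton fallback

assert expected_kernel(False, "1", "0") == "w8a8_block_fp8_matmul"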
vllm/model_executor/layers/quantization/fp8.py

@@ -182,6 +182,13 @@ class Fp8LinearMethod(LinearMethodBase):
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        # AITER is only supported on ROCm and only for FP8_FNUZ;
+        # at the moment that means the MI300 series.
+        self.use_aiter_and_is_supported = (current_platform.is_rocm()
+                                           and envs.VLLM_ROCM_USE_AITER
+                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
+                                           and current_platform.is_fp8_fnuz())
+
         self.block_quant = self.quant_config.weight_block_size is not None
         self.fp8_linear = Fp8LinearOp(
             # Default to using per_token quantization if cutlass is supported
@@ -402,6 +409,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 input_scale=layer.input_scale,
                 bias=bias,
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
 
         return self.fp8_linear.apply(input=x,
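All four conditions in the gate above must hold; turning either flag off, or running on a platform without FP8_FNUZ, silently keeps the pre-existing kernels. A minimal sketch of how a user would opt in (hypothetical launcher snippet; the environment variable names are the ones read by vllm.envs in the diff above):

import os

# Set both flags before vLLM reads its environment, e.g. at the top of a
# launcher script; leaving either at "0" disables the AITER GEMM path.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "1"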
vllm/model_executor/layers/quantization/utils/fp8_utils.py

@@ -4,7 +4,7 @@
 import functools
 import json
 import os
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 
@@ -27,6 +27,76 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
 
 
+def cutlass_scaled_mm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    return ops.cutlass_scaled_mm(A,
+                                 B.T,
+                                 out_dtype=output_dtype,
+                                 scale_a=As,
+                                 scale_b=Bs.T)
+
+
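This wrapper transposes both the weight and its scale so a row-major [N, K] weight matches the kernel's expected layout; block_size is accepted only to keep the shared dispatch signature and is not used in the call. A shape sketch under the usual [128, 128] block size (shapes illustrative):

# A:  [M, K]  fp8 activations    As: [M, K // 128]          per-token-group scales
# B:  [N, K]  fp8 weights        Bs: [N // 128, K // 128]   per-block scales
# ops.cutlass_scaled_mm(A, B.T, scale_a=As, scale_b=Bs.T)  -> [M, N] output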
+def rocm_aiter_gemm_w8a8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    import aiter as rocm_aiter
+
+    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+
+
+def rocm_aiter_gemm_w8a8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8_blockscale",
+        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
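Registration through direct_register_custom_op exposes the kernel as torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale, and the shape-only fake_impl lets torch.compile trace it without touching AITER (the aiter import is likewise deferred into the impl, so the module still imports cleanly on non-ROCm builds). A call sketch, assuming a ROCm machine with the aiter package installed; shapes and values are illustrative:

import torch

M, N, K = 32, 512, 1024
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)  # activations
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fnuz)  # weights
As = torch.rand(M, K // 128, device="cuda")                     # activation scales
Bs = torch.rand(N // 128, K // 128, device="cuda")              # weight block scales

out = torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
    A, B, As, Bs, [128, 128], torch.float16)                    # -> [M, N] fp16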
+def dispatch_w8a8_blockscale_func(
+    use_cutlass: bool, use_aiter_and_is_supported: bool
+) -> Callable[[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        list[int],
+        torch.dtype,
+], torch.Tensor]:
+    if use_cutlass:
+        return cutlass_scaled_mm
+    if use_aiter_and_is_supported:
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    return w8a8_block_fp8_matmul
+
+
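The precedence is fixed: CUTLASS if requested, then AITER, then the Triton w8a8_block_fp8_matmul fallback. All three candidates share the (A, B, As, Bs, block_size, output_dtype) signature, so the caller can hold whichever one wins in a single variable. A usage sketch, assuming the names are imported from vllm.model_executor.layers.quantization.utils.fp8_utils:

fn = dispatch_w8a8_blockscale_func(use_cutlass=False,
                                   use_aiter_and_is_supported=False)
assert fn is w8a8_block_fp8_matmul  # Triton fallback when both gates are off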
 # TODO fix ROCm->Triton custom path:
 #  https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -37,26 +107,23 @@ def apply_w8a8_block_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
-                                  and weight.shape[1] % 128 == 0)
-    if current_platform.is_rocm():
-        # TODO this is never used, as cutlass_block_fp8_supported is False
-        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
-                         input_2d.shape[:-1])[::-1]
-        scale_b_shape = (weight_scale.view(-1, 1)
-                         if weight_scale.dim() <= 1 else weight_scale.T).shape
-        ar, ac = scale_a_shape
-        br, bc = scale_b_shape
-        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
-                or br not in (1, weight.shape[0])):
-            shape_supported_by_cutlass = False
-    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+    if current_platform.is_cuda():
+        use_cutlass = cutlass_block_fp8_supported and (
+            weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+    else:
+        use_cutlass = False
+
+    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported)
+
+    if use_cutlass:
         rows, cols = input_2d.shape
         # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
         # optimal tensor core usage. Can be removed when targeting platforms
@@ -67,26 +134,22 @@ def apply_w8a8_block_fp8_linear(
             input_2d = torch.nn.functional.pad(input_2d,
                                                (0, 0, 0, 4 - (rows % 4)),
                                                value=0).contiguous()
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=True)
-        output = ops.cutlass_scaled_mm(q_input,
-                                       weight.T,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale.T)
+
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
         if should_pad:
             output = output[:rows, :]
 
     else:
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=False)
-        output = w8a8_block_fp8_matmul(q_input,
-                                       weight,
-                                       x_scale,
-                                       weight_scale,
-                                       block_size,
-                                       output_dtype=input.dtype)
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
 
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
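The SM100 branch above pads the flattened token dimension up to the next multiple of 4 before the GEMM and slices the pad rows off afterwards. The arithmetic, spelled out with illustrative numbers (the should_pad flag guarding it is computed earlier in the function, outside this hunk):

rows = 10                     # tokens in the flattened input_2d
pad = 4 - (rows % 4)          # 2 zero rows appended, as in F.pad above
assert (rows + pad) % 4 == 0  # 12 rows reach the kernel
# afterwards, output = output[:rows, :] restores the original row count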
@@ -98,6 +161,9 @@ def apply_w8a8_block_fp8_linear_fake(
     block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     output_shape = [*input.shape[:-1], weight.shape[0]]
     return torch.empty(output_shape, dtype=input.dtype, device=input.device)