[Performance][ROCm] Add skinny gemms for unquantized linear on ROCm (#15830)

Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Charlie Fu authored 2025-04-21 22:46:22 -05:00, committed by GitHub
commit 188b7f9b8c (parent b9b4746950)
12 changed files with 1955 additions and 93 deletions

CMakeLists.txt

@@ -678,6 +678,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
#
set(VLLM_ROCM_EXT_SRC
"csrc/rocm/torch_bindings.cpp"
"csrc/rocm/skinny_gemms.cu"
"csrc/rocm/attention.cu")
define_gpu_extension_target(

csrc/rocm/ops.h

@@ -2,6 +2,15 @@
#include <torch/all.h>
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,

csrc/rocm/skinny_gemms.cu (new file, 1600 lines added)

File diff suppressed because it is too large.

csrc/rocm/torch_bindings.cpp

@@ -14,6 +14,24 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// vLLM custom ops for rocm
// Custom gemm op for matrix-vector multiplication
rocm_ops.def(
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
"Tensor");
rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
"Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// wvSplitK for fp8
rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()");
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
// Custom attention op
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
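The ops registered above live under the _rocm_C namespace; the thin Python wrappers added in vllm/_custom_ops.py (shown further down) call them via torch.ops._rocm_C. Below is a minimal usage sketch, assuming a ROCm build of vLLM with this extension compiled and hypothetical tensor sizes; the argument convention (weight first, skinny activation second) follows the tests added in this commit.

# Minimal sketch: exercising the new ops through the vllm._custom_ops wrappers.
# Assumes a ROCm GPU and a vLLM build that includes csrc/rocm/skinny_gemms.cu.
import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform

weight = torch.rand(8192, 4096, dtype=torch.float16, device="cuda")  # (m, k)
x = torch.rand(1, 4096, dtype=torch.float16, device="cuda")          # (n, k), n == 1

out1 = ops.LLMM1(weight, x, 4)                                   # rows_per_block = 4
out2 = ops.wvSplitK(weight, x, current_platform.get_cu_count())  # split-K skinny GEMM
# Both approximate x @ weight.t(), i.e. an (n, m) result.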

tests/kernels/test_rocm_skinny_gemms.py (new file)

@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0
N = [1, 2, 3, 4]
SEEDS = [0]
@pytest.mark.parametrize("n", [1]) # only test for batch size 1
@pytest.mark.parametrize("k", K)
@pytest.mark.parametrize("m", M)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed)
A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda")
ref_out = torch.matmul(A, B.t())
out = ops.LLMM1(B, A, rows_per_block)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
@pytest.mark.parametrize("m", [8] + M) # m >= 8
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda")
ref_out = torch.matmul(A, B.t())
out = ops.wvSplitK(B, A, cu_count)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda")
B = torch.rand(m, k, device="cuda")
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count())
assert torch.allclose(out, ref_out, rtol=0.01)
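To run just these kernel tests on a ROCm machine, something like the following should work; the file path is an assumption based on the imports above and may need adjusting.

# Run only the skinny-GEMM kernel tests (path assumed).
import pytest

pytest.main(["-q", "tests/kernels/test_rocm_skinny_gemms.py"])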

vllm/_custom_ops.py

@@ -1196,6 +1196,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states, pad_slot_id)
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor,
rows_per_block: int) -> torch.Tensor:
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor,
cu_count: int) -> torch.Tensor:
out = torch.empty((b.shape[0], a.shape[0]),
dtype=out_dtype,
device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
return out
# moe
def moe_sum(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum(input, output)

vllm/envs.py

@@ -78,6 +78,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
@@ -550,6 +551,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
("true", "1")),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
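The new flag defaults to on. For A/B benchmarking, the skinny path can be disabled by setting the variable before it is read; a small sketch, assuming vllm.envs resolves these lambdas lazily on attribute access as the surrounding code suggests:

# Sketch: disable the ROCm skinny-GEMM path. Any value other than
# "true"/"1" (case-insensitive) turns the flag off.
import os

os.environ["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"

import vllm.envs as envs

assert envs.VLLM_ROCM_USE_SKINNY_GEMM is False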

vllm/model_executor/layers/linear.py

@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
@@ -188,7 +188,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
return dispatch_unquantized_gemm()(x, layer.weight, bias)
class LinearBase(torch.nn.Module):

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
@@ -17,6 +18,7 @@ TORCH_DEVICE_IDENTITY = None
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and torch.__version__[0:3] >= "2.7"
and current_platform.has_device_capability(94))
@@ -131,6 +133,159 @@ def maybe_create_device_identity():
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return output.view(*output_shape)
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool,
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:
if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
return torch_per_token_w8a8_scaled_mm
return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
@@ -156,7 +311,8 @@ class Fp8LinearOp:
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE
self.output_padding = 17 if pad_output else None
self.output_padding = 17 if (
pad_output and not current_platform.is_rocm()) else None
def apply(
self,
@@ -195,18 +351,6 @@
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__
@@ -218,84 +362,21 @@
else:
qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.cutlass_fp8_supported, per_tensor_weights,
per_tensor_activations, use_per_token_if_dynamic)
return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
elif (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm support rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
input_2d=input_2d,
output_shape=output_shape)
def normalize_e4m3fn_to_e4m3fnuz(
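The net effect of this refactor is that Fp8LinearOp.apply now picks one of the helpers above through dispatch_w8a8_scaled_mm instead of branching inline. A rough illustration of the selection logic; the module path is taken from this diff and the call is a sketch, not exhaustive:

# Which implementation dispatch_w8a8_scaled_mm hands back for two configurations.
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    dispatch_w8a8_scaled_mm)

# Per-tensor weight and activation scales, no cutlass fp8 support:
func = dispatch_w8a8_scaled_mm(cutlass_fp8_supported=False,
                               per_tensor_weights=True,
                               per_tensor_activations=True,
                               use_per_token_if_dynamic=None)
# -> rocm_per_tensor_w8a8_scaled_mm on ROCm (which may route batch-size-1,
#    k % 16 == 0 inputs to wvSplitKQ), torch_per_tensor_w8a8_scaled_mm elsewhere.

# Channelwise weight scales with per-tensor activations fall back to the
# unfused dequantization path:
func = dispatch_w8a8_scaled_mm(cutlass_fp8_supported=False,
                               per_tensor_weights=False,
                               per_tensor_activations=True,
                               use_per_token_if_dynamic=False)
# -> torch_channelwise_w8a8_scaled_mm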

vllm/model_executor/layers/utils.py

@@ -1,9 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from typing import Tuple
from typing import Callable, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
@@ -61,3 +65,34 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.view(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and n < 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
if current_platform.is_rocm():
return rocm_unquantized_gemm
return torch.nn.functional.linear
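As a sanity check, the dispatched GEMM should agree with F.linear to within the tolerance used by the new tests. A small sketch with hypothetical shapes; the skinny path only triggers for fp16/bf16 inputs with k divisible by 8 and no bias, and requires a ROCm GPU with the custom ops built:

# Compare the dispatched GEMM against F.linear on a skinny shape (n = 1),
# which should hit the wvSplitK path on ROCm and plain F.linear elsewhere.
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm

x = torch.rand(1, 4096, dtype=torch.float16, device="cuda")      # (n, k)
w = torch.rand(6144, 4096, dtype=torch.float16, device="cuda")   # (m, k)

gemm = dispatch_unquantized_gemm()
torch.testing.assert_close(gemm(x, w, None), F.linear(x, w), rtol=1e-2, atol=1e-2)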

vllm/platforms/interface.py

@@ -413,6 +413,13 @@ class Platform:
self.device_name, key)
return None
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
"""
Returns the total number of compute units (CUs) on a single GPU.
"""
raise NotImplementedError
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/rocm.py

@@ -312,3 +312,8 @@ class RocmPlatform(Platform):
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ['gfx94']
return any(gfx in gcn_arch for gfx in supported_archs)
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(
device_id).multi_processor_count