[Kernel][Quantization] Integrate block-quantized CUTLASS kernels for DeepSeekV3 (#12587)

Integrates the block-quantized kernels introduced in
https://github.com/vllm-project/vllm/pull/11868 for use in linear
layers.

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: Tyler Michael Smith, 2025-01-31 18:29:11 -05:00 (committed by GitHub)
Commit: eb5741ad42 (parent: 145c2ff648)
8 changed files with 160 additions and 37 deletions
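
A brief orientation sketch (not part of the commit) of the tensor and scale shapes in the DeepSeekV3-style W8A8 block-FP8 scheme this change targets, assuming the usual block_size = [128, 128]; the dispatch condition mirrors the check added to apply_w8a8_block_fp8_linear below. All concrete values are illustrative.

    # Illustrative shapes only; M/K/N values are made up.
    M, K, N = 16, 7168, 4096
    block_n, block_k = 128, 128

    # The FP8 weight of shape (N, K) carries one scale per 128x128 block.
    weight_scale_shape = (N // block_n, K // block_k)

    # FP8 activations of shape (M, K) carry one scale per [1, 128] group,
    # i.e. per token and per 128 input channels.
    x_scale_shape = (M, K // block_k)

    # The new CUTLASS path is taken only when both weight dims divide by 128;
    # otherwise the existing Triton kernel handles the GEMM.
    use_cutlass_path = (N % 128 == 0) and (K % 128 == 0)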

View File

@@ -153,6 +153,7 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
 #ifndef USE_ROCM
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,

View File

@@ -58,7 +58,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
     vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
   } else {
-    TORCH_CHECK(false, "Unsupported scale group shapes for CUTLASS 3.x GEMM");
+    TORCH_CHECK(false,
+                "Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
+                "a_scale_group_shape must be [1, 128], got: [",
+                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                "]\n"
+                "b_scale_group_shape must be [128, 128], got: [",
+                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
   }
 }
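
The improved error message talks about "scale group shapes". A hypothetical sketch of how these group shapes relate to operand and scale shapes (the real inference happens on the C++ side; the shapes here are made up):

    import torch

    M, K, N = 16, 512, 256
    a_scales = torch.ones(M, K // 128)         # per-token groups of 128 elements
    b_scales = torch.ones(K // 128, N // 128)  # 128x128 weight blocks

    # Group shape = how many operand elements one scale entry covers.
    a_scale_group_shape = (M // a_scales.shape[0], K // a_scales.shape[1])
    b_scale_group_shape = (K // b_scales.shape[0], N // b_scales.shape[1])
    assert a_scale_group_shape == (1, 128)
    assert b_scale_group_shape == (128, 128)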

View File

@@ -81,6 +81,19 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }
 
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
+  // CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
+  // and at least SM90 (Hopper)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
@@ -212,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
       "No compiled cutlass_scaled_mm_azp for a compute capability less than "
       "CUDA device capability: ",
       version_num);
 }
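
For context, a small usage sketch of the new support check through the Python wrapper added later in this commit; the True result for SM90 assumes a CUDA >= 12.0 build.

    from vllm import _custom_ops as ops

    for capability in (80, 89, 90):  # Ampere, Ada, Hopper
        print(capability, ops.cutlass_scaled_mm_supports_block_fp8(capability))
    # Expected: 80 False, 89 False, 90 True (the block-FP8 CUTLASS kernels are SM90-only)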

View File

@@ -324,6 +324,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+  // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
+  ops.def(
+      "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
+      "bool");
+  ops.impl("cutlass_scaled_mm_supports_block_fp8",
+           &cutlass_scaled_mm_supports_block_fp8);
+
   // Check if cutlass sparse scaled_mm is supported for CUDA devices of the
   // given capability
   ops.def(

View File

@@ -435,6 +435,11 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
 
+def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(
+        cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,

View File

@@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_block_fp8_supported, cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
@@ -133,6 +134,7 @@ class Fp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
@@ -359,6 +361,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
             )
 
         return apply_fp8_linear(

View File

@@ -8,6 +8,7 @@ import torch
 import triton
 import triton.language as tl
 
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -21,20 +22,34 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = True,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
-    output = w8a8_block_fp8_matmul(q_input,
-                                   weight,
-                                   x_scale,
-                                   weight_scale,
-                                   block_size,
-                                   output_dtype=input.dtype)
+    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
+                                  and weight.shape[1] % 128 == 0)
+    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=True)
+        output = ops.cutlass_scaled_mm(q_input,
+                                       weight.T,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale.T)
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=False)
+        output = w8a8_block_fp8_matmul(q_input,
+                                       weight,
+                                       x_scale,
+                                       weight_scale,
+                                       block_size,
+                                       output_dtype=input.dtype)
 
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
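
A shape-level sketch (values made up, not from the commit) of what flows into the new CUTLASS branch. The transposes are there because cutlass_scaled_mm takes b as (K, N) with correspondingly transposed block scales, while the activation scales are produced column-major by the quantization helper:

    import torch

    M, K, N = 8, 512, 256
    weight = torch.empty(N, K)                      # FP8 weight as stored
    weight_scale = torch.empty(N // 128, K // 128)  # one scale per 128x128 block

    # per_token_group_quant_fp8(..., column_major_scales=True) yields
    #   q_input: (M, K) FP8 tensor, x_scale: (M, K // 128) with column-major strides.
    # cutlass_scaled_mm then receives:
    #   a = q_input, b = weight.T, scale_a = x_scale, scale_b = weight_scale.T
    print(weight.T.shape, weight_scale.T.shape)     # (512, 256), (4, 2)
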
@@ -98,10 +113,7 @@ def _per_token_group_quant_fp8(
     y_ptr,
     y_q_ptr,
     y_s_ptr,
-    # Stride of input
-    y_stride,
-    # Columns of input
-    N,
+    group_size,
     # Avoid to divide zero
     eps,
     # Information for float8
@@ -116,12 +128,60 @@
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
-    y_ptr += g_id * y_stride
-    y_q_ptr += g_id * y_stride
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
     y_s_ptr += g_id
 
     cols = tl.arange(0, BLOCK)  # N <= BLOCK
-    mask = cols < N
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id, the flattened block coordinate, to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
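
A pure-Python sketch (not from the commit) of the scale-index math used by _per_token_group_quant_fp8_colmajor: each program quantizes one [1, group_size] group and writes a single scale into the column-major y_s buffer.

    def scale_offset(g_id, y_num_columns, group_size, y_s_col_stride):
        # Mirrors the kernel: recover the (row, col) of the scale from the flat group id.
        blocks_per_row = y_num_columns // group_size
        scale_col = g_id % blocks_per_row   # which group within the token's row
        scale_row = g_id // blocks_per_row  # which token (row of y)
        return scale_col * y_s_col_stride + scale_row

    # For y of shape (4, 256) with group_size=128, y_s is logically (4, 2) but
    # stored column-major, so y_s_col_stride == 4 (one entry per token).
    print([scale_offset(g, 256, 128, 4) for g in range(8)])
    # -> [0, 4, 1, 5, 2, 6, 3, 7]
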
@@ -138,12 +198,13 @@ def per_token_group_quant_fp8(
     group_size: int,
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
+    column_major_scales: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
-        x: The input tenosr with ndim >= 2.
+        x: The input tensor with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
@@ -167,29 +228,46 @@ def per_token_group_quant_fp8(
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size, ),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
+        x_s = torch.empty(shape, device=x.device,
+                          dtype=torch.float32).permute(-1, -2)
+    else:
+        shape = x.shape[:-1] + (x.shape[-1] // group_size, )
+        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M, )](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M, )](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M, )](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
 
     return x_q, x_s
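
A hedged usage example of the extended helper (requires a CUDA device; the import path assumes the helper lives in vllm.model_executor.layers.quantization.utils.fp8_utils, as in current vLLM):

    import torch
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
    x_q, x_s = per_token_group_quant_fp8(x, 128, column_major_scales=True)

    print(x_q.shape)                # torch.Size([4, 256]), FP8 values
    print(x_s.shape, x_s.stride())  # torch.Size([4, 2]), (1, 4): column-major scales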

View File

@@ -30,6 +30,16 @@ def cutlass_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
+def cutlass_block_fp8_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+
+    return ops.cutlass_scaled_mm_supports_block_fp8(capability)
+
+
 def per_tensor_dequantize(
     tensor: torch.Tensor, inv_scale: Union[float,
                                            torch.Tensor]) -> torch.Tensor: