mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:16:06 +08:00
[Kernel][Quantization] Integrate block-quantized CUTLASS kernels for DeepSeekV3 (#12587)
Integrates the block-quantized kernels introduced in https://github.com/vllm-project/vllm/pull/11868 for use in linear layers. Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
145c2ff648
commit
eb5741ad42
@ -153,6 +153,7 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
|||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||||
|
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||||
|
|
||||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||||
|
|||||||
@ -58,7 +58,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
|
|
||||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
|
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported scale group shapes for CUTLASS 3.x GEMM");
|
TORCH_CHECK(false,
|
||||||
|
"Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
|
||||||
|
"a_scale_group_shape must be [1, 128], got: [",
|
||||||
|
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||||
|
"]\n"
|
||||||
|
"b_scale_group_shape must be [128, 128], got: [",
|
||||||
|
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -81,6 +81,19 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||||
|
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
||||||
|
// and at least SM90 (Hopper)
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION
|
||||||
|
if (cuda_device_capability >= 90) {
|
||||||
|
return CUDA_VERSION >= 12000;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::Tensor const& b_scales,
|
||||||
|
|||||||
@ -324,6 +324,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||||
|
|
||||||
|
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||||
|
"bool");
|
||||||
|
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||||
|
&cutlass_scaled_mm_supports_fp8);
|
||||||
|
|
||||||
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
||||||
// given capability
|
// given capability
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
@ -435,6 +435,11 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
|
|||||||
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
|
||||||
|
return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(
|
||||||
|
cuda_device_capability)
|
||||||
|
|
||||||
|
|
||||||
def cutlass_scaled_mm(a: torch.Tensor,
|
def cutlass_scaled_mm(a: torch.Tensor,
|
||||||
b: torch.Tensor,
|
b: torch.Tensor,
|
||||||
scale_a: torch.Tensor,
|
scale_a: torch.Tensor,
|
||||||
|
|||||||
@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
is_layer_skipped)
|
is_layer_skipped)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
||||||
cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||||
|
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||||
requantize_with_max_scale)
|
requantize_with_max_scale)
|
||||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
@ -133,6 +134,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||||
|
|
||||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||||
# kernel for fast weight-only FP8 quantization
|
# kernel for fast weight-only FP8 quantization
|
||||||
@ -359,6 +361,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
weight_scale=layer.weight_scale_inv,
|
weight_scale=layer.weight_scale_inv,
|
||||||
input_scale=layer.input_scale,
|
input_scale=layer.input_scale,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
|
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||||
)
|
)
|
||||||
|
|
||||||
return apply_fp8_linear(
|
return apply_fp8_linear(
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@ -21,20 +22,34 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
input_scale: Optional[torch.Tensor] = None,
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
cutlass_block_fp8_supported: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert input_scale is None
|
assert input_scale is None
|
||||||
# View input as 2D matrix for fp8 methods
|
# View input as 2D matrix for fp8 methods
|
||||||
input_2d = input.view(-1, input.shape[-1])
|
input_2d = input.view(-1, input.shape[-1])
|
||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
|
|
||||||
q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
|
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
|
||||||
|
and weight.shape[1] % 128 == 0)
|
||||||
|
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
|
||||||
|
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||||
|
block_size[1],
|
||||||
|
column_major_scales=True)
|
||||||
|
output = ops.cutlass_scaled_mm(q_input,
|
||||||
|
weight.T,
|
||||||
|
out_dtype=input.dtype,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=weight_scale.T)
|
||||||
|
else:
|
||||||
|
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||||
|
block_size[1],
|
||||||
|
column_major_scales=False)
|
||||||
output = w8a8_block_fp8_matmul(q_input,
|
output = w8a8_block_fp8_matmul(q_input,
|
||||||
weight,
|
weight,
|
||||||
x_scale,
|
x_scale,
|
||||||
weight_scale,
|
weight_scale,
|
||||||
block_size,
|
block_size,
|
||||||
output_dtype=input.dtype)
|
output_dtype=input.dtype)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output = output + bias
|
output = output + bias
|
||||||
return output.to(dtype=input.dtype).view(*output_shape)
|
return output.to(dtype=input.dtype).view(*output_shape)
|
||||||
@ -98,10 +113,7 @@ def _per_token_group_quant_fp8(
|
|||||||
y_ptr,
|
y_ptr,
|
||||||
y_q_ptr,
|
y_q_ptr,
|
||||||
y_s_ptr,
|
y_s_ptr,
|
||||||
# Stride of input
|
group_size,
|
||||||
y_stride,
|
|
||||||
# Columns of input
|
|
||||||
N,
|
|
||||||
# Avoid to divide zero
|
# Avoid to divide zero
|
||||||
eps,
|
eps,
|
||||||
# Information for float8
|
# Information for float8
|
||||||
@ -116,12 +128,60 @@ def _per_token_group_quant_fp8(
|
|||||||
"""
|
"""
|
||||||
# Map the program id to the row of X and Y it should compute.
|
# Map the program id to the row of X and Y it should compute.
|
||||||
g_id = tl.program_id(0)
|
g_id = tl.program_id(0)
|
||||||
y_ptr += g_id * y_stride
|
y_ptr += g_id * group_size
|
||||||
y_q_ptr += g_id * y_stride
|
y_q_ptr += g_id * group_size
|
||||||
y_s_ptr += g_id
|
y_s_ptr += g_id
|
||||||
|
|
||||||
cols = tl.arange(0, BLOCK) # N <= BLOCK
|
cols = tl.arange(0, BLOCK) # N <= BLOCK
|
||||||
mask = cols < N
|
mask = cols < group_size
|
||||||
|
|
||||||
|
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||||
|
# Quant
|
||||||
|
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||||
|
y_s = _absmax / fp8_max
|
||||||
|
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||||
|
tl.store(y_s_ptr, y_s)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _per_token_group_quant_fp8_colmajor(
|
||||||
|
# Pointers to inputs and output
|
||||||
|
y_ptr,
|
||||||
|
y_q_ptr,
|
||||||
|
y_s_ptr,
|
||||||
|
group_size,
|
||||||
|
# Num columns of y
|
||||||
|
y_num_columns,
|
||||||
|
# Stride from one column to the next of y_s
|
||||||
|
y_s_col_stride,
|
||||||
|
# Avoid to divide zero
|
||||||
|
eps,
|
||||||
|
# Information for float8
|
||||||
|
fp8_min,
|
||||||
|
fp8_max,
|
||||||
|
# Meta-parameters
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""A Triton-accelerated function to perform per-token-group
|
||||||
|
quantization on a tensor.
|
||||||
|
This function converts the tensor values into float8 values.
|
||||||
|
"""
|
||||||
|
# Map the program id to the row of X and Y it should compute.
|
||||||
|
g_id = tl.program_id(0)
|
||||||
|
y_ptr += g_id * group_size
|
||||||
|
y_q_ptr += g_id * group_size
|
||||||
|
|
||||||
|
# Convert g_id the flattened block coordinate to 2D so we can index
|
||||||
|
# into the output y_scales matrix
|
||||||
|
blocks_per_row = y_num_columns // group_size
|
||||||
|
scale_col = g_id % blocks_per_row
|
||||||
|
scale_row = g_id // blocks_per_row
|
||||||
|
y_s_ptr += scale_col * y_s_col_stride + scale_row
|
||||||
|
|
||||||
|
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
|
||||||
|
mask = cols < group_size
|
||||||
|
|
||||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||||
# Quant
|
# Quant
|
||||||
@ -138,12 +198,13 @@ def per_token_group_quant_fp8(
|
|||||||
group_size: int,
|
group_size: int,
|
||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
column_major_scales: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||||
It converts the tensor values into signed float8 values and returns the
|
It converts the tensor values into signed float8 values and returns the
|
||||||
quantized tensor along with the scaling factor used for quantization.
|
quantized tensor along with the scaling factor used for quantization.
|
||||||
Args:
|
Args:
|
||||||
x: The input tenosr with ndim >= 2.
|
x: The input tensor with ndim >= 2.
|
||||||
group_size: The group size used for quantization.
|
group_size: The group size used for quantization.
|
||||||
eps: The minimum to avoid dividing zero.
|
eps: The minimum to avoid dividing zero.
|
||||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||||
@ -167,22 +228,39 @@ def per_token_group_quant_fp8(
|
|||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||||
M = x.numel() // group_size
|
M = x.numel() // group_size
|
||||||
N = group_size
|
N = group_size
|
||||||
x_s = torch.empty(
|
if column_major_scales:
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size, ),
|
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
|
||||||
device=x.device,
|
x_s = torch.empty(shape, device=x.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32).permute(-1, -2)
|
||||||
)
|
else:
|
||||||
|
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
|
||||||
|
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||||
|
|
||||||
BLOCK = triton.next_power_of_2(N)
|
BLOCK = triton.next_power_of_2(N)
|
||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
if column_major_scales:
|
||||||
|
_per_token_group_quant_fp8_colmajor[(M, )](
|
||||||
|
x,
|
||||||
|
x_q,
|
||||||
|
x_s,
|
||||||
|
group_size,
|
||||||
|
x.shape[1],
|
||||||
|
x_s.stride(1),
|
||||||
|
eps,
|
||||||
|
fp8_min=fp8_min,
|
||||||
|
fp8_max=fp8_max,
|
||||||
|
BLOCK=BLOCK,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=num_stages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
_per_token_group_quant_fp8[(M, )](
|
_per_token_group_quant_fp8[(M, )](
|
||||||
x,
|
x,
|
||||||
x_q,
|
x_q,
|
||||||
x_s,
|
x_s,
|
||||||
group_size,
|
group_size,
|
||||||
N,
|
|
||||||
eps,
|
eps,
|
||||||
fp8_min=fp8_min,
|
fp8_min=fp8_min,
|
||||||
fp8_max=fp8_max,
|
fp8_max=fp8_max,
|
||||||
|
|||||||
@ -30,6 +30,16 @@ def cutlass_fp8_supported() -> bool:
|
|||||||
return ops.cutlass_scaled_mm_supports_fp8(capability)
|
return ops.cutlass_scaled_mm_supports_fp8(capability)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_block_fp8_supported() -> bool:
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return False
|
||||||
|
|
||||||
|
capability_tuple = current_platform.get_device_capability()
|
||||||
|
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||||
|
|
||||||
|
return ops.cutlass_scaled_mm_supports_block_fp8(capability)
|
||||||
|
|
||||||
|
|
||||||
def per_tensor_dequantize(
|
def per_tensor_dequantize(
|
||||||
tensor: torch.Tensor, inv_scale: Union[float,
|
tensor: torch.Tensor, inv_scale: Union[float,
|
||||||
torch.Tensor]) -> torch.Tensor:
|
torch.Tensor]) -> torch.Tensor:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user