mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:45:01 +08:00
[Kernel] Vectorized FP8 quantize kernel (#5396)
Inspired by #5146, this PR improves FP8 quantize kernel by vectorizing data transfer to better utilize memory bandwidth. Microbenchmark shows that this improved kernel can achieve 1.0x-1.5x speedup (especially when hidden size is large). In details, we applied 3 optimizations: - Use inverted scale so that most divisions are changed to multiplications. - Unroll the loop by 4 times to improve ILP. - Use vectorized 4 to transfer data between HBM and SRAM.
This commit is contained in:
parent
8b82a89997
commit
5985e3427d
@ -23,8 +23,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
|||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
||||||
const scalar_t val, const float scale) {
|
const scalar_t val, const float inverted_scale) {
|
||||||
float x = static_cast<float>(val) / scale;
|
float x = static_cast<float>(val) * inverted_scale;
|
||||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||||
return static_cast<c10::Float8_e4m3fn>(r);
|
return static_cast<c10::Float8_e4m3fn>(r);
|
||||||
}
|
}
|
||||||
@ -71,15 +71,56 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
struct __align__(8) vec4_t {
|
||||||
|
scalar_t x;
|
||||||
|
scalar_t y;
|
||||||
|
scalar_t z;
|
||||||
|
scalar_t w;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct __align__(4) {
|
||||||
|
c10::Float8_e4m3fn x;
|
||||||
|
c10::Float8_e4m3fn y;
|
||||||
|
c10::Float8_e4m3fn z;
|
||||||
|
c10::Float8_e4m3fn w;
|
||||||
|
}
|
||||||
|
float8x4_t;
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||||
const scalar_t* __restrict__ input,
|
const scalar_t* __restrict__ input,
|
||||||
const float* __restrict__ scale,
|
const float* __restrict__ scale,
|
||||||
int64_t num_elems) {
|
int64_t num_elems) {
|
||||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
while (i < num_elems) {
|
|
||||||
out[i] = scaled_fp8_conversion(input[i], *scale);
|
// Invert the scale so that we can use multiplications to avoid expensive
|
||||||
i += blockDim.x * gridDim.x;
|
// division.
|
||||||
|
const float inverted_scale = 1.0f / (*scale);
|
||||||
|
|
||||||
|
// Vectorized input/output to better utilize memory bandwidth.
|
||||||
|
const vec4_t<scalar_t>* vectorized_in =
|
||||||
|
reinterpret_cast<const vec4_t<scalar_t>*>(input);
|
||||||
|
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||||
|
|
||||||
|
int num_vec_elems = num_elems >> 2;
|
||||||
|
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
|
||||||
|
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||||
|
float8x4_t out_vec;
|
||||||
|
|
||||||
|
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
|
||||||
|
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
|
||||||
|
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
|
||||||
|
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
|
||||||
|
vectorized_out[i] = out_vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the remaining elements if num_elems is not divisible by 4
|
||||||
|
for (int i = num_vec_elems * 4 + tid; i < num_elems;
|
||||||
|
i += blockDim.x * gridDim.x) {
|
||||||
|
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_fp8.py --forked`.
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
|
|
||||||
@ -22,3 +23,49 @@ def test_load_fp16_model(vllm_runner) -> None:
|
|||||||
fc1 = model.model.decoder.layers[0].fc1
|
fc1 = model.model.decoder.layers[0].fc1
|
||||||
assert isinstance(fc1.quant_method, Fp8LinearMethod)
|
assert isinstance(fc1.quant_method, Fp8LinearMethod)
|
||||||
assert fc1.weight.dtype == torch.float8_e4m3fn
|
assert fc1.weight.dtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
|
||||||
|
reason="FP8 is not supported on this GPU type.")
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
def test_scaled_fp8_quant(dtype) -> None:
|
||||||
|
|
||||||
|
def quantize_ref(tensor, inv_scale):
|
||||||
|
# The reference implementation that fully aligns to
|
||||||
|
# the kernel being tested.
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
scale = inv_scale.reciprocal()
|
||||||
|
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
|
||||||
|
max=finfo.max)
|
||||||
|
qweight = qweight.to(torch.float8_e4m3fn)
|
||||||
|
return qweight
|
||||||
|
|
||||||
|
def per_tensor_dequantize(tensor, inv_scale, dtype):
|
||||||
|
fake_qweight = tensor.to(dtype)
|
||||||
|
dq_weight = fake_qweight * inv_scale
|
||||||
|
return dq_weight
|
||||||
|
|
||||||
|
# Note that we use a shape % 4 != 0 to cover edge cases,
|
||||||
|
# because scaled_fp8_quant is vectorized by 4.
|
||||||
|
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
|
||||||
|
|
||||||
|
# Dynamic quantization
|
||||||
|
ref_y, inv_scale = scaled_fp8_quant(x, None)
|
||||||
|
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
|
||||||
|
|
||||||
|
# Reference dynamic quantizaton
|
||||||
|
y = quantize_ref(x, inv_scale)
|
||||||
|
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||||
|
|
||||||
|
# Static quantization
|
||||||
|
y, _ = scaled_fp8_quant(x, inv_scale)
|
||||||
|
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||||
|
|
||||||
|
# Padding
|
||||||
|
y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
|
||||||
|
assert y.shape[0] == 17
|
||||||
|
assert torch.allclose(
|
||||||
|
ref_y,
|
||||||
|
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
|
||||||
|
dtype))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user