[Kernel] Vectorized FP8 quantize kernel (#5396)

Inspired by #5146, this PR improves the FP8 quantize kernel by vectorizing data transfers to better utilize memory bandwidth. A microbenchmark shows that the improved kernel achieves a 1.0x-1.5x speedup, especially when the hidden size is large.
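The microbenchmark script itself is not included in this diff. For reproduction, a minimal timing harness along the following lines should suffice; the helper name time_dynamic_quant, the shapes, and the iteration count are illustrative choices rather than part of the PR. Run it before and after the change and compare the averages per hidden size.

import torch

from vllm._custom_ops import scaled_fp8_quant


def time_dynamic_quant(num_tokens: int, hidden_size: int, iters: int = 100) -> float:
    # Average milliseconds per dynamic FP8 quantization call on the current GPU.
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
    scaled_fp8_quant(x, None)  # warm-up launch
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        scaled_fp8_quant(x, None)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


for hidden_size in (1024, 4096, 8192, 16384):
    print(hidden_size, time_dynamic_quant(num_tokens=1024, hidden_size=hidden_size))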

In detail, we applied three optimizations:

- Use the inverted scale so that the per-element divisions become multiplications (only the single 1/scale division remains).
- Unroll the loop by a factor of 4 to improve instruction-level parallelism (ILP).
- Use vectorized loads/stores of 4 elements to transfer data between HBM and SRAM (a rough PyTorch sketch of the resulting per-element math follows this list).
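As a rough mental model of what the kernel computes (a PyTorch sketch under the assumption of per-tensor scaling, not vLLM code; fp8_quantize_sketch is a made-up name), the first optimization amounts to folding the per-element division into a single reciprocal:

import torch


def fp8_quantize_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # One division up front; each element is then multiplied by the inverted
    # scale and clamped to the representable E4M3 range before the cast.
    finfo = torch.finfo(torch.float8_e4m3fn)  # finfo.max == 448.0, matching FP8_E4M3_MAX
    inverted_scale = scale.reciprocal()
    q = (x.to(torch.float32) * inverted_scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn)

The unrolling and the 4-wide vectorized transfers only change how the kernel walks memory, not this arithmetic.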
Cody Yu authored on 2024-06-12 14:07:26 -07:00 (committed by GitHub)
commit 5985e3427d, parent 8b82a89997
2 changed files with 94 additions and 6 deletions

File 1 of 2: the FP8 quantize CUDA kernel

@@ -23,8 +23,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
 
 template <typename scalar_t>
 __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    const scalar_t val, const float scale) {
-  float x = static_cast<float>(val) / scale;
+    const scalar_t val, const float inverted_scale) {
+  float x = static_cast<float>(val) * inverted_scale;
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
   return static_cast<c10::Float8_e4m3fn>(r);
 }
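One caveat worth spelling out (my illustration, not part of the diff): multiplying by a precomputed reciprocal involves two roundings instead of one, so the pre-cast value can differ from a true division by one ulp. The much coarser FP8 rounding usually absorbs this, and the new test below compares dequantized outputs with torch.allclose rather than exact equality.

import torch

x = torch.rand(1_000_000, dtype=torch.float32) + 0.5
scale = torch.tensor(0.3, dtype=torch.float32)
mismatches = (x / scale != x * scale.reciprocal()).sum().item()
print(mismatches)  # typically nonzero: the two forms round differently for some inputs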
@@ -71,15 +71,56 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   }
 }
 
+template <typename scalar_t>
+struct __align__(8) vec4_t {
+  scalar_t x;
+  scalar_t y;
+  scalar_t z;
+  scalar_t w;
+};
+
+typedef struct __align__(4) {
+  c10::Float8_e4m3fn x;
+  c10::Float8_e4m3fn y;
+  c10::Float8_e4m3fn z;
+  c10::Float8_e4m3fn w;
+} float8x4_t;
+
 template <typename scalar_t>
 __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
-  int i = blockDim.x * blockIdx.x + threadIdx.x;
-  while (i < num_elems) {
-    out[i] = scaled_fp8_conversion(input[i], *scale);
-    i += blockDim.x * gridDim.x;
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+  // Invert the scale so that we can use multiplications to avoid expensive
+  // division.
+  const float inverted_scale = 1.0f / (*scale);
+
+  // Vectorized input/output to better utilize memory bandwidth.
+  const vec4_t<scalar_t>* vectorized_in =
+      reinterpret_cast<const vec4_t<scalar_t>*>(input);
+  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+
+  int num_vec_elems = num_elems >> 2;
+
+#pragma unroll 4
+  for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
+    vec4_t<scalar_t> in_vec = vectorized_in[i];
+    float8x4_t out_vec;
+
+    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
+    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
+    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
+    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
+    vectorized_out[i] = out_vec;
+  }
+
+  // Handle the remaining elements if num_elems is not divisible by 4
+  for (int i = num_vec_elems * 4 + tid; i < num_elems;
+       i += blockDim.x * gridDim.x) {
+    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
   }
 }
 
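The vectorized loop covers the first (num_elems >> 2) * 4 elements and the scalar tail loop covers the remaining 0-3, both in a grid-stride pattern. The index math can be sanity-checked outside CUDA with a standalone sketch (illustrative only; num_threads stands in for blockDim.x * gridDim.x):

def covered_indices(num_elems: int, num_threads: int) -> list:
    # Mirrors the index math of scaled_fp8_quant_kernel in plain Python.
    covered = []
    num_vec_elems = num_elems >> 2
    for tid in range(num_threads):
        # Vectorized grid-stride loop: each iteration writes 4 consecutive elements.
        for i in range(tid, num_vec_elems, num_threads):
            covered.extend(range(4 * i, 4 * i + 4))
        # Scalar tail loop for the last num_elems % 4 elements.
        for i in range(num_vec_elems * 4 + tid, num_elems, num_threads):
            covered.append(i)
    return covered


# 11 * 11 elements matches the deliberately non-multiple-of-4 shape used in the test below;
# sorted equality with range() checks that every element is written exactly once.
assert sorted(covered_indices(11 * 11, num_threads=96)) == list(range(11 * 11))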

File 2 of 2: tests/quantization/test_fp8.py

@@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_fp8.py --forked`.
 import pytest
 import torch
 
+from vllm._custom_ops import scaled_fp8_quant
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 
@@ -22,3 +23,49 @@ def test_load_fp16_model(vllm_runner) -> None:
     fc1 = model.model.decoder.layers[0].fc1
     assert isinstance(fc1.quant_method, Fp8LinearMethod)
     assert fc1.weight.dtype == torch.float8_e4m3fn
+
+
+@pytest.mark.skipif(
+    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
+    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_scaled_fp8_quant(dtype) -> None:
+
+    def quantize_ref(tensor, inv_scale):
+        # The reference implementation that fully aligns to
+        # the kernel being tested.
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        scale = inv_scale.reciprocal()
+        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
+                                                           max=finfo.max)
+        qweight = qweight.to(torch.float8_e4m3fn)
+        return qweight
+
+    def per_tensor_dequantize(tensor, inv_scale, dtype):
+        fake_qweight = tensor.to(dtype)
+        dq_weight = fake_qweight * inv_scale
+        return dq_weight
+
+    # Note that we use a shape % 4 != 0 to cover edge cases,
+    # because scaled_fp8_quant is vectorized by 4.
+    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
+
+    # Dynamic quantization
+    ref_y, inv_scale = scaled_fp8_quant(x, None)
+    ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
+
+    # Reference dynamic quantization
+    y = quantize_ref(x, inv_scale)
+    assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
+
+    # Static quantization
+    y, _ = scaled_fp8_quant(x, inv_scale)
+    assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
+
+    # Padding
+    y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    assert y.shape[0] == 17
+    assert torch.allclose(
+        ref_y,
+        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
+                              dtype))