mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[ Kernel ] FP8 Dynamic-Per-Token Quant Kernel (#6511)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
e76466dde2
commit
b5241e41d9
10
csrc/ops.h
10
csrc/ops.h
@ -128,12 +128,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale);
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
|
||||
@ -7,6 +7,8 @@
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "../../reduction_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
@ -88,25 +90,48 @@ typedef struct __align__(4) {
|
||||
float8x4_t;
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// Invert the scale so that we can use multiplications to avoid expensive
|
||||
// division.
|
||||
const float inverted_scale = 1.0f / (*scale);
|
||||
|
||||
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
int64_t const num_elems, int const tid,
|
||||
int const step) {
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
const vec4_t<scalar_t>* vectorized_in =
|
||||
reinterpret_cast<const vec4_t<scalar_t>*>(input);
|
||||
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
vec4_t<scalar_t> const* vectorized_in =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
|
||||
int num_vec_elems = num_elems >> 2;
|
||||
int const num_vec_elems = num_elems >> 2;
|
||||
float absmax_val = 0.0f;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
|
||||
for (int i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
absmax_val = max(absmax_val, fabs(in_vec.x));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.y));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.z));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.w));
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
absmax_val = max(absmax_val, fabs(input[i]));
|
||||
}
|
||||
|
||||
return absmax_val;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const inverted_scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vectorized_in =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
|
||||
int const num_vec_elems = num_elems >> 2;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
|
||||
@ -118,17 +143,74 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int i = num_vec_elems * 4 + tid; i < num_elems;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// Invert the scale so that we can use multiplications to avoid expensive
|
||||
// division.
|
||||
const float inverted_scale = 1.0f / (*scale);
|
||||
|
||||
scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
|
||||
blockDim.x * gridDim.x);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
|
||||
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
|
||||
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
|
||||
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
bool const can_vectorize = hidden_size % 4 == 0;
|
||||
|
||||
float absmax_val = 0.0f;
|
||||
if (can_vectorize) {
|
||||
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
float const x = static_cast<float>(token_input[i]);
|
||||
absmax_val = max(absmax_val, fabs(x));
|
||||
}
|
||||
}
|
||||
|
||||
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
||||
__shared__ float block_absmax_val;
|
||||
if (tid == 0) {
|
||||
block_absmax_val = block_absmax_val_maybe;
|
||||
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
|
||||
if (can_vectorize) {
|
||||
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
|
||||
hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
@ -144,9 +226,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
@ -163,3 +245,25 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(), hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
@ -179,12 +179,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
||||
|
||||
// Compute FP8 quantized tensor and scaling factor.
|
||||
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
||||
"()");
|
||||
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
|
||||
|
||||
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
|
||||
"scale) -> "
|
||||
"()");
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||
&dynamic_per_token_scaled_fp8_quant);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size.
|
||||
ops.def(
|
||||
|
||||
56
tests/kernels/quant_utils.py
Normal file
56
tests/kernels/quant_utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
|
||||
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
|
||||
|
||||
def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||
quant_dtype: torch.dtype) \
|
||||
-> Tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
|
||||
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
|
||||
else torch.finfo(quant_dtype)
|
||||
qtype_max = as_float32_tensor(qtype_traits.max)
|
||||
|
||||
# For fp8, in order to match the cuda kernel output, we have to do exactly
|
||||
# the same operations as in the corresponding fp8 kernel to prevent
|
||||
# rounding errors.
|
||||
|
||||
# Compute scales
|
||||
x_token_max, _ = x.abs().max(dim=-1)
|
||||
x_token_max = as_float32_tensor(x_token_max)
|
||||
scales = (x_token_max / qtype_max)[:, None]
|
||||
|
||||
# Quant
|
||||
iscales = (qtype_max / x_token_max)[:, None]
|
||||
torch_out = as_float32_tensor(x) * iscales
|
||||
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
|
||||
torch_out = torch_out.clamp(qtype_traits.min,
|
||||
qtype_traits.max).to(quant_dtype)
|
||||
|
||||
return torch_out, scales
|
||||
|
||||
|
||||
# The int8 version is very similar. Incorporate the int8 version, like in
|
||||
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
|
||||
# kernel
|
||||
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
||||
-> Tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
fp8_traits = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = as_float32_tensor(fp8_traits.max)
|
||||
one = as_float32_tensor(1.0)
|
||||
|
||||
# For fp8, in order to match the cuda kernel output, we have to do exactly
|
||||
# the same operations as in the corresponding fp8 kernel to prevent
|
||||
# rounding errors.
|
||||
|
||||
x_max = as_float32_tensor(x.abs().max())
|
||||
ref_scale = x_max / fp8_max
|
||||
ref_iscale = one / ref_scale
|
||||
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
|
||||
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
|
||||
return ref_out, ref_scale
|
||||
54
tests/kernels/test_fp8_quant.py
Normal file
54
tests/kernels/test_fp8_quant.py
Normal file
@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
|
||||
ref_dynamic_per_token_quant)
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
|
||||
8193] # Arbitrary values for testing
|
||||
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
|
||||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") + 1e-6 # avoid nans
|
||||
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
|
||||
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
|
||||
|
||||
assert torch.allclose(ref_scales, ops_scales)
|
||||
assert torch.allclose(ref_out.to(dtype=torch.float32),
|
||||
ops_out.to(dtype=torch.float32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, ops_scale = ops.scaled_fp8_quant(x)
|
||||
|
||||
assert torch.allclose(ref_scale, ops_scale)
|
||||
assert torch.allclose(ref_out.to(dtype=torch.float32),
|
||||
ops_out.to(dtype=torch.float32))
|
||||
@ -3,6 +3,8 @@ import torch
|
||||
|
||||
# ruff: noqa: F401
|
||||
import vllm._C
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
|
||||
from vllm._custom_ops import scaled_int8_quant
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
|
||||
@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
|
||||
x_token_max, _ = x.max(dim=1)
|
||||
x_token_max = x_token_max.to(dtype=torch.float32)
|
||||
scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
|
||||
dtype=torch.float32)
|
||||
torch_out = (x / scales).round().clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
# reference
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
|
||||
# kernel
|
||||
ops_out, ops_scales = scaled_int8_quant(x)
|
||||
|
||||
ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
|
||||
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
|
||||
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
|
||||
|
||||
assert torch.allclose(scales_out, scales)
|
||||
assert torch.allclose(torch_out, ops_out,
|
||||
assert torch.allclose(ops_scales, ref_scales)
|
||||
assert torch.allclose(ops_out, ref_out,
|
||||
atol=1) # big atol to account for rounding errors
|
||||
|
||||
|
||||
@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
scale = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
|
||||
out1 = (x / scale).round().clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
out2 = torch.empty_like(x, dtype=torch.int8)
|
||||
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
out2, _ = scaled_int8_quant(x, scale)
|
||||
|
||||
torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
|
||||
assert torch.allclose(out1, out2,
|
||||
atol=1) # big atol to account for rounding errors
|
||||
|
||||
@ -335,6 +335,17 @@ def scaled_fp8_quant(
|
||||
return output, scale
|
||||
|
||||
|
||||
def dynamic_per_token_scaled_fp8_quant(
|
||||
input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
|
||||
scales = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
|
||||
return output, scales
|
||||
|
||||
|
||||
# int8
|
||||
def scaled_int8_quant(
|
||||
input: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user