mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:25:32 +08:00
[Perf] Use NVIDIA hardware-accelerated instruction for float to fp8_e4m3 quantization (#24757)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
parent
30498f2a65
commit
dbeee3844c
@ -5,7 +5,9 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#ifndef USE_ROCM
|
||||
#include "nvidia/quant_utils.cuh"
|
||||
#else
|
||||
#include "amd/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float r =
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
// Use hardware cvt instruction for fp8 on nvidia
|
||||
// Currently only support fp8_type = c10::Float8_e4m3fn
|
||||
return fp8::vec_conversion<fp8_type, float>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
|
||||
@ -12,13 +12,26 @@ namespace vllm {
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout
|
||||
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__inline__ __device__ Tout vec_conversion(
|
||||
const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// float -> c10::Float8_e4m3fn
|
||||
template <>
|
||||
__inline__ __device__ c10::Float8_e4m3fn
|
||||
vec_conversion<c10::Float8_e4m3fn, float>(
|
||||
const float& a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return static_cast<c10::Float8_e4m3fn>(a);
|
||||
#else
|
||||
return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user