// vllm/csrc/quantization/fp8/common.cuh

#pragma once

#include "quantization/vectorization.cuh"
#include "quantization/utils.cuh"

#include <cmath>

#ifdef USE_ROCM
  // <string> and ATen's CUDAContext are needed by is_fp8_ocp() below.
  #include <string>

  #include <ATen/cuda/CUDAContext.h>

  #include "amd/quant_utils.cuh"
#endif

// Determines whether the preferred FP8 format for the current platform is
// OCP E4M3. On CUDA this is always true; on ROCm, gfx94x (MI300-series)
// devices use the FNUZ variant instead, so the device arch must be checked.
static bool is_fp8_ocp() {
#ifndef USE_ROCM
  return true;
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  std::string device_arch = dprops->gcnArchName;
  size_t substring = device_arch.find("gfx94");
  // gfx94x prefers FNUZ FP8; every other supported arch prefers OCP FP8.
  return substring == std::string::npos;
#endif
}
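
// Illustrative usage sketch (hypothetical host-side dispatch, not part of
// this header): picking the FP8 torch dtype from is_fp8_ocp(). The tensor
// names are assumptions; the c10 scalar types are PyTorch's FP8 types.
//
//   at::ScalarType fp8_dtype = is_fp8_ocp()
//                                  ? at::ScalarType::Float8_e4m3fn
//                                  : at::ScalarType::Float8_e4m3fnuz;
//   at::Tensor out = at::empty_like(input, input.options().dtype(fp8_dtype));
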
namespace vllm {

// Atomic max for floats built from integer atomics: for non-negative values
// the signed-int bit-pattern ordering matches the float ordering, so signed
// atomicMax is correct; for negative values the unsigned ordering is
// reversed, so unsigned atomicMin yields the float max instead.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}
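
// Illustrative sketch (hypothetical kernel, not part of this header): using
// atomicMaxFloat to reduce a tensor's absolute maximum, the first step of
// computing a dynamic per-tensor FP8 scale. *max_out must be initialized to
// 0.0f before launch.
//
//   template <typename scalar_t>
//   __global__ void absmax_kernel(float* max_out, const scalar_t* input,
//                                 int64_t num_elems) {
//     float local_max = 0.0f;
//     for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//          i < num_elems; i += (int64_t)gridDim.x * blockDim.x) {
//       local_max = fmaxf(local_max, fabsf(static_cast<float>(input[i])));
//     }
//     atomicMaxFloat(max_out, local_max);
//   }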

// Quantizes a single float to fp8_type, saturating to the representable FP8
// range. When is_scale_inverted is true the caller passes 1/scale and the
// value is multiplied; otherwise the value is divided by scale.
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  // Clamp to the FP8 max magnitude before converting to avoid overflow.
  float r =
      fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
  return static_cast<fp8_type>(r);
#else
  // Use the hardware cvt instruction for FP8 on ROCm.
  return fp8::cvt_c10<fp8_type>(r);
#endif
}
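
// Illustrative usage sketch (hypothetical call site, not from this file):
// quantizing one element with a precomputed inverted scale, the common fast
// path that replaces a per-element divide with a multiply.
//
//   float inverted_scale = 1.0f / scale;
//   out[i] = scaled_fp8_conversion<true, c10::Float8_e4m3fn>(
//       static_cast<float>(input[i]), inverted_scale);
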
} // namespace vllm