diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu index 2609054f2072b..f5b40e35b6e5a 100644 --- a/csrc/quantization/fp8/per_token_group_quant.cu +++ b/csrc/quantization/fp8/per_token_group_quant.cu @@ -1,12 +1,10 @@ #include -#include #include "../per_token_group_quant_8bit.h" #include -#include -#include +#include #include @@ -199,7 +197,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input, VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "per_token_group_quant_8bit", ([&] { if (dst_type == at::ScalarType::Float8_e4m3fn) { - LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn); + LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3); } else if (dst_type == at::ScalarType::Char) { LAUNCH_KERNEL(scalar_t, int8_t); }