From 1b0a15553420e5459d9a8512a3f1bd7d4117db08 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Tue, 29 Jul 2025 23:50:46 -0400
Subject: [PATCH] [Perf] Using `__nv_fp8_e4m3` instead of `c10::e4m3` for
 `per_token_group_quant` (#21867)

Signed-off-by: yewentao256 <44945378+yewentao256@users.noreply.github.com>
---
 csrc/quantization/fp8/per_token_group_quant.cu | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu
index 2609054f2072b..f5b40e35b6e5a 100644
--- a/csrc/quantization/fp8/per_token_group_quant.cu
+++ b/csrc/quantization/fp8/per_token_group_quant.cu
@@ -1,12 +1,10 @@
 #include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
 
 #include "../per_token_group_quant_8bit.h"
 
 #include <cmath>
-#include <c10/util/Float8_e4m3fn.h>
-#include <c10/util/Float8_e5m2.h>
+#include <cuda_fp8.h>
 
 #include <torch/all.h>
 
@@ -199,7 +197,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "per_token_group_quant_8bit", ([&] {
         if (dst_type == at::ScalarType::Float8_e4m3fn) {
-          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+          LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
         } else if (dst_type == at::ScalarType::Char) {
           LAUNCH_KERNEL(scalar_t, int8_t);
         }