[Perf] Using __nv_fp8_e4m3 instead of c10::e4m3 for per_token_group_quant (#21867)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date:   2025-07-29 23:50:46 -04:00 (committed by GitHub)
Parent: 44bc46da60
Commit: 1b0a155534


@@ -1,12 +1,10 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <c10/util/Float8_e4m3fn.h>
 #include "../per_token_group_quant_8bit.h"
 #include <cmath>
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <torch/all.h>
@@ -199,7 +197,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "per_token_group_quant_8bit", ([&] {
         if (dst_type == at::ScalarType::Float8_e4m3fn) {
-          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+          LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
         } else if (dst_type == at::ScalarType::Char) {
           LAUNCH_KERNEL(scalar_t, int8_t);
         }
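
Why this is a perf win: c10::Float8_e4m3fn converts from float in software (scalar bit manipulation), whereas __nv_fp8_e4m3 from <cuda_fp8.h> can lower to the hardware FP8 convert instruction on Ada/Hopper (SM89+). Below is a minimal, hypothetical sketch of the conversion path the kernel takes after this change; the kernel name, launch shape, and inv_scale parameter are illustrative and not the PR's actual code.

#include <cuda_fp8.h>

// Illustrative kernel: the __nv_fp8_e4m3 float constructor performs a
// saturating E4M3 conversion (equivalently,
// __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3)), which maps to a
// hardware instruction on SM89+ instead of c10's software bit twiddling.
__global__ void quant_to_fp8_sketch(const float* __restrict__ in,
                                    __nv_fp8_e4m3* __restrict__ out,
                                    float inv_scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Saturating conversion: values outside the E4M3 range clamp to +/-448.
    out[i] = __nv_fp8_e4m3(in[i] * inv_scale);
  }
}

Since __nv_fp8_e4m3 and c10::Float8_e4m3fn share the same E4M3 bit layout, the output tensor is still valid Float8_e4m3fn storage on the PyTorch side; only the device-side conversion routine changes.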