diff --git a/CMakeLists.txt b/CMakeLists.txt
index 767e9ad7541b0..98ed682fee7d9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -245,7 +245,6 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
-  "csrc/quantization/fp8/per_token_group_quant.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/quantization/activation_kernels.cu"
@@ -297,7 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp"
-    "csrc/attention/mla/cutlass_mla_entry.cu")
+    "csrc/attention/mla/cutlass_mla_entry.cu"
+    "csrc/quantization/fp8/per_token_group_quant.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
diff --git a/csrc/ops.h b/csrc/ops.h
index fdd3071c56ef2..97a247d9d628c 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -287,6 +287,11 @@ void scaled_fp4_experts_quant(
     torch::Tensor const& input, torch::Tensor const& input_global_scale,
     torch::Tensor const& input_offset_by_experts,
     torch::Tensor const& output_scale_offset_by_experts);
+
+void per_token_group_quant_fp8(const torch::Tensor& input,
+                               torch::Tensor& output_q, torch::Tensor& output_s,
+                               int64_t group_size, double eps, double fp8_min,
+                               double fp8_max, bool scale_ue8m0);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
@@ -297,11 +302,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                                torch::Tensor& scales,
                                std::optional<torch::Tensor> const& azp);
 
-void per_token_group_quant_fp8(const torch::Tensor& input,
-                               torch::Tensor& output_q, torch::Tensor& output_s,
-                               int64_t group_size, double eps, double fp8_min,
-                               double fp8_max, bool scale_ue8m0);
-
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index d310211afe432..95f8541bc9e2d 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -601,15 +601,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
-  // Compute per-token-group FP8 quantized tensor and scaling factor.
-  ops.def(
-      "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
-      "output_s, "
-      "int group_size, float eps, float fp8_min, float fp8_max, bool "
-      "scale_ue8m0) -> ()");
-  ops.impl("per_token_group_fp8_quant", torch::kCUDA,
-           &per_token_group_quant_fp8);
-
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
@@ -624,6 +615,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
 #ifndef USE_ROCM
+  // Compute per-token-group FP8 quantized tensor and scaling factor.
+  ops.def(
+      "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
+      "output_s, "
+      "int group_size, float eps, float fp8_min, float fp8_max, bool "
+      "scale_ue8m0) -> ()");
+  ops.impl("per_token_group_fp8_quant", torch::kCUDA,
+           &per_token_group_quant_fp8);
+
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
       "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
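
The patch only relocates the `per_token_group_fp8_quant` declaration, schema, and kernel source behind the CUDA-only guard; the op schema itself is unchanged. For reference, a minimal sketch of calling the relocated op from Python on a CUDA build is shown below. The `torch.ops._C` namespace, the e4m3 output dtype, the per-group scale layout, and the `eps` value are assumptions for illustration and are not taken from this diff.

```python
# Hedged usage sketch, assuming the op is registered under torch.ops._C
# (the usual vLLM extension namespace) and quantizes to float8_e4m3fn.
import torch

num_tokens, hidden_size, group_size = 4, 256, 128
x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")

# One FP8 value per input element; one scale per (token, group) -- assumed layout.
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_size // group_size,
                  dtype=torch.float32, device="cuda")

fp8_info = torch.finfo(torch.float8_e4m3fn)
# eps=1e-10 is a placeholder clamp value, not taken from this patch.
torch.ops._C.per_token_group_fp8_quant(
    x, x_q, x_s, group_size, 1e-10, fp8_info.min, fp8_info.max, False)
```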