From e502098643b863f09dd9d8e74249523641f01298 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 25 Nov 2025 09:59:07 -0500 Subject: [PATCH] [Kernel] Add NVFP4 MoE CUTLASS support for SM120 (#29242) Signed-off-by: mgoin Signed-off-by: Michael Goin --- CMakeLists.txt | 5 +- .../fp4/nvfp4_blockwise_moe_kernel.cu | 230 +++++++++++++++++- csrc/quantization/fp4/nvfp4_experts_quant.cu | 2 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 10 +- .../w8a8/cutlass/scaled_mm_entry.cu | 27 +- docs/design/moe_kernel_features.md | 2 +- .../compressed_tensors_moe.py | 14 +- .../layers/quantization/modelopt.py | 4 + 8 files changed, 264 insertions(+), 30 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 86746a0db4c0e..d88ba3aa66303 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -604,12 +604,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() message(STATUS "Not building NVFP4 as no compatible archs were found.") diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 5b007e5ea3283..6744402783832 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -22,6 +22,7 @@ #include #include #include +#include "cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" @@ -173,7 +174,7 @@ void run_get_group_gemm_starts( } template -void run_fp4_blockwise_scaled_group_mm( +void run_fp4_blockwise_scaled_group_mm_sm100( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, @@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm( auto can_implement_status = gemm_op.can_implement(args); TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); + "Failed to implement GEMM: status=", (int)can_implement_status); // Run the GEMM auto status = gemm_op.initialize(args, workspace.data_ptr()); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, " num_experts=", num_experts, + " M=", M, " N=", N, " K=", K); status = gemm_op.run(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } +void run_fp4_blockwise_scaled_group_mm_sm120( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = 
cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + // NOTE: For SM120 it seems templating the output type is not supported and + // we need to hardcode the output type to bfloat16 + using ElementC = cutlass::bfloat16_t; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + + using FusionOperation = cutlass::epilogue::fusion::LinearCombination< + ElementD, ElementAccumulator, ElementC, ElementAccumulator>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutD*, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, + LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + 
torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + fusion_args.beta = 0.0f; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM: status=", (int)can_implement_status); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, " num_experts=", num_experts, + " M=", M, " N=", N, " K=", K); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + int32_t version_num = 
get_sm_version_num(); +#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + if (version_num >= 120 && version_num < 130) { + run_fp4_blockwise_scaled_group_mm_sm120( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + return; + } +#endif #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 + if (version_num >= 100 && version_num < 120) { + run_fp4_blockwise_scaled_group_mm_sm100( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + return; + } +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ", + version_num, ". Required capability: 100 or 120"); +} + +#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ + (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; #endif @@ -374,7 +583,8 @@ void cutlass_fp4_group_mm( const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 +#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ + (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) // Input validation CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); @@ -408,6 +618,14 @@ void cutlass_fp4_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); } else { + #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + int32_t version_num = get_sm_version_num(); + if (version_num >= 120 && version_num < 130) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ", + output.scalar_type()); + } + #endif run_fp4_blockwise_scaled_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); @@ -416,8 +634,8 @@ void cutlass_fp4_group_mm( TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_fp4_group_mm kernel, vLLM must " - "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA " - "12.8 or above."); + "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 " + "and CUDA 12.8 or above."); #endif } diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 6d385e0dd94e7..82c53c2375a31 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float; constexpr auto INT = at::ScalarType::Int; constexpr auto UINT8 = at::ScalarType::Byte; -void scaled_fp4_experts_quant_sm100a( +void scaled_fp4_experts_quant_sm1xxa( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index c2b39e5438805..fb6d22f035b99 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input_sf); #endif -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 -void scaled_fp4_experts_quant_sm100a( +#if (defined(ENABLE_NVFP4_SM100) && 
ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void scaled_fp4_experts_quant_sm1xxa( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, @@ -54,8 +55,9 @@ void scaled_fp4_experts_quant( torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 - return scaled_fp4_experts_quant_sm100a( +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return scaled_fp4_experts_quant_sm1xxa( output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); #endif diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 1001af05ff003..c5012a8669317 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif -#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ - defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \ - defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120 +#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \ + (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \ + (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120) void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k, @@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data( false, "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". 
Required capability: 90, 100, or 120"); } void get_cutlass_moe_mm_problem_sizes( @@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes( torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets) { int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets); @@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes( false, "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " "kernel for CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". Required capability: 90, 100, or 120"); } void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, @@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, num_local_experts, padded_m, n, k); @@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, false, "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " "for CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". Required capability: 90, 100, or 120"); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index f0d5a3e934f39..e54a9e2bc5e77 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. 
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] -- [`CompressedTensorsW4A4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4MoeMethod] +- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 149e4419c64a4..71d7de97d4a10 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -103,7 +103,7 @@ __all__ = [ "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4MoeMethod", + "CompressedTensorsW4A4Nvfp4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod", ] @@ -171,7 +171,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): quant_config, layer.moe_config ) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod(layer.moe_config) + return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config) elif ( quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) @@ -188,7 +188,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ) -class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): +class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): def __init__(self, moe: FusedMoEConfig): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support, @@ -205,8 +205,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - " for CompressedTensorsW4A4MoeMethod." + " for CompressedTensorsW4A4Nvfp4MoeMethod." ) + elif self.use_marlin: + logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.") + else: + logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.") def create_weights( self, @@ -612,7 +616,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): assert expert_map is None, ( "Expert Parallelism / expert_map " "is currently not supported for " - "CompressedTensorsW4A4MoeMethod." + "CompressedTensorsW4A4Nvfp4MoeMethod." 
) assert self.moe_quant_config is not None diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 8165673135910..2cf7089e0ff90 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1132,6 +1132,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" " for ModelOptNvFp4FusedMoE." ) + elif self.use_marlin: + logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.") + else: + logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.") def maybe_make_prepare_finalize( self,
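
A minimal, hypothetical sketch of the runtime dispatch this patch introduces: run_fp4_blockwise_scaled_group_mm now routes devices with capability 100-119 to the SM100 grouped GEMM and 120-129 to the new SM120 grouped GEMM (which, per the note in nvfp4_blockwise_moe_kernel.cu, only emits bfloat16 output). The helper name below is illustrative and not part of the patch; it assumes get_sm_version_num() maps the device capability to major * 10 + minor, and that the corresponding ENABLE_NVFP4_SM100 / ENABLE_NVFP4_SM120 paths were compiled in (set by CMakeLists.txt when FP4_ARCHS matches).

    # Hypothetical helper (not part of this patch): mirrors the capability
    # ranges used by run_fp4_blockwise_scaled_group_mm when picking a kernel.
    import torch

    def nvfp4_cutlass_moe_backend() -> str | None:
        """Return which NVFP4 CUTLASS MoE kernel would be selected, if any."""
        if not torch.cuda.is_available():
            return None
        major, minor = torch.cuda.get_device_capability()
        version_num = major * 10 + minor  # assumed to match get_sm_version_num()
        if 100 <= version_num < 120:
            return "sm100"  # run_fp4_blockwise_scaled_group_mm_sm100
        if 120 <= version_num < 130:
            return "sm120"  # run_fp4_blockwise_scaled_group_mm_sm120 (bf16 output only)
        return None  # cutlass_fp4_group_mm raises TORCH_CHECK_NOT_IMPLEMENTED

Note that even on a capability-120 device this path is only reachable when vLLM was built with ENABLE_NVFP4_SM120 (and ENABLE_CUTLASS_MOE_SM120 for the MoE data-preparation entry points); otherwise the entry point falls through to the "No compiled cutlass_fp4_group_mm kernel" error shown above.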