From f6227c22ab8976a24913122874c24624102da1b4 Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Mon, 8 Dec 2025 22:29:06 -0500 Subject: [PATCH] [Kernel]Support W4A8 Grouped GEMM on Hopper (#29691) Signed-off-by: czhu-cohere --- CMakeLists.txt | 5 +- csrc/ops.h | 3 +- .../cutlass_w4a8/get_group_starts.cuh | 104 ++++ .../cutlass_w4a8/w4a8_grouped_mm_entry.cu | 483 ++++++++++++++++++ .../cutlass_w4a8/w4a8_mm_entry.cu | 70 +-- csrc/quantization/cutlass_w4a8/w4a8_utils.cu | 90 ++++ csrc/quantization/cutlass_w4a8/w4a8_utils.cuh | 11 + .../quantization/w8a8/cutlass/moe/moe_data.cu | 8 +- .../w8a8/cutlass/scaled_mm_entry.cu | 8 +- csrc/torch_bindings.cpp | 26 +- .../kernels/quantization/test_cutlass_w4a8.py | 46 +- .../quantization/test_cutlass_w4a8_moe.py | 340 ++++++++++++ vllm/_custom_ops.py | 90 +++- .../layers/fused_moe/__init__.py | 4 + .../model_executor/layers/fused_moe/config.py | 29 ++ .../layers/fused_moe/cutlass_moe.py | 401 +++++++++++++++ .../layers/fused_moe/modular_kernel.py | 2 +- .../compressed_tensors/compressed_tensors.py | 2 +- .../compressed_tensors_moe.py | 340 +++++++++++- .../schemes/compressed_tensors_w4a8_fp8.py | 15 +- .../kernels/mixed_precision/cutlass.py | 19 +- .../layers/quantization/utils/quant_utils.py | 50 +- 22 files changed, 2045 insertions(+), 101 deletions(-) create mode 100644 csrc/quantization/cutlass_w4a8/get_group_starts.cuh create mode 100644 csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu create mode 100644 csrc/quantization/cutlass_w4a8/w4a8_utils.cu create mode 100644 csrc/quantization/cutlass_w4a8/w4a8_utils.cuh create mode 100644 tests/kernels/quantization/test_cutlass_w4a8_moe.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 69a538b06cba3..6b93e3fe91603 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -874,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) set(SRCS - 
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" + ) set_gencode_flags_for_srcs( SRCS "${SRCS}" diff --git a/csrc/ops.h b/csrc/ops.h index 5fce3a1a3fea3..37e3aaf7499d5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_problem_sizes( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets); + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt); void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, diff --git a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh new file mode 100644 index 0000000000000..fec142d0d87a1 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh @@ -0,0 +1,104 @@ +// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh +#pragma once + +#include +#include +#include + +#include "core/scalar_type.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + +// ElementB is int32 (packed int4) +// ElementGroupScale is cutlass::Array (packed fp8) +template +__global__ void get_group_gemm_starts( + int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scales_offsets, + ElementAccumulator** b_scales_offsets, + ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int, + ElementB* b_base_as_int, ElementC* out_base_as_int, + ElementAccumulator* a_scales_base_as_int, + ElementAccumulator* b_scales_base_as_int, + ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k, + int64_t scale_k) { + int expert_id = threadIdx.x; + + int64_t 
expert_offset = expert_offsets[expert_id]; + + // same as w8a8 + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset; + b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id); + + // w4a8 specific + constexpr int pack_factor = 8; // pack 8 int4 into int32 + b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor); + b_group_scales_offsets[expert_id] = + b_group_scales_base_as_int + (expert_id * scale_k * n); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts> \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast**>( \ + b_group_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast*>( \ + b_group_scales.data_ptr()), \ + n, k, scale_k); \ + } + +namespace { + +void run_get_group_gemm_starts( + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor& out_tensors, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& b_group_scales, const int64_t b_group_size) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32 + TORCH_CHECK(a_scales.dtype() == 
torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_group_scales.dtype() == + torch::kFloat8_e4m3fn); // the underlying torch type is e4m3 + TORCH_CHECK(out_tensors.dtype() == + torch::kBFloat16); // only support bf16 for now + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); + + int num_experts = static_cast(expert_offsets.size(0)); + // logical k, n + int64_t n = out_tensors.size(1); + int64_t k = a_tensors.size(1); + int64_t scale_k = cutlass::ceil_div(k, b_group_size); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu new file mode 100644 index 0000000000000..4b425790dbac7 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -0,0 +1,483 @@ +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +// vllm includes +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" +#include "cutlass_extensions/common.hpp" + +#include "core/registration.h" +#include "get_group_starts.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "w4a8_utils.cuh" + +namespace 
vllm::cutlass_w4a8_moe { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // per + // group +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; + +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr PackFactor = 8; // 8 int4 packed into int32 + +// A matrix configuration +using ElementA = MmaType; +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = + 128 / + cutlass::sizeof_bits::value; // Alignment of A matrix in units of + // elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input +// layouts +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; + +// Need to pass a pointer type to make the 3rd dimension of Stride be _0 +using StrideA = + cute::remove_pointer_t>; +using StrideB = + cute::remove_pointer_t>; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. 
It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout>, StrideB>{})); + +using ElementScale = cutlass::float_e4m3_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + +// per-channel and per-token scales for epilogue +using ElementSChannel = float; + +template +struct W4A8GroupedGemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // per-channel, per-token scales epilogue + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogueArray; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementSChannel, ElementC, + typename 
cutlass::layout::LayoutTranspose::type*, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type*, + AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + // =========================================================== MIXED INPUT + // WITH SCALES + // =========================================================================== + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. + using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>; + + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::InternalStrideC; + using StrideD = typename GemmKernelShuffled::InternalStrideD; + + using StrideC_ref = cutlass::detail::TagToStrideC_t; + using StrideD_ref = cutlass::detail::TagToStrideC_t; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + using StrideS_ref = cutlass::detail::TagToStrideB_t; + + // static asserts for passing in strides/layouts + // pack to 2x int64 + static_assert(sizeof(StrideS) == 2 * sizeof(int64_t)); + // pack to 3xint32, + static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0, + "LayoutB_Reordered size must be divisible by 4 bytes"); + + static void grouped_mm( + torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& 
b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides) { + auto device = a_tensors.device(); + auto device_id = device.index(); + const at::cuda::OptionalCUDAGuard device_guard(device); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + int num_experts = static_cast(expert_offsets.size(0)); + int n = static_cast(b_tensors.size(1)); + int k = static_cast(b_tensors.size(2)) * PackFactor; + + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(device); + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int); + + // get the correct offsets to pass to gemm + run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs, + a_tensors, b_tensors, out_tensors, a_scales, + b_scales, b_group_scales, b_group_size); + + // construct args + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + Args arguments; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast( + problem_sizes_torch.data_ptr()); + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; + + // SwapAB so B operands come first + MainloopArguments mainloop_arguments{ + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr()), + static_cast(a_ptrs.data_ptr()), + 
static_cast(a_strides.data_ptr()), + static_cast**>( + b_group_scales_ptrs.data_ptr()), + static_cast(group_scale_strides.data_ptr()), + static_cast(b_group_size)}; + + EpilogueArguments epilogue_arguments{ + // since we are doing SwapAB the channel scales comes first, then token + // scales + ChTokScalesEpilogue::prepare_args( // see ScaledEpilogueArray + static_cast( + b_scales_ptrs.data_ptr()), // per-channel + static_cast( + a_scales_ptrs.data_ptr()), // per-token + true, true), + nullptr, // C + static_cast(c_strides.data_ptr()), // C + static_cast(out_ptrs.data_ptr()), // D + static_cast(c_strides.data_ptr()) // D + }; + + static const cutlass::KernelHardwareInfo hw_info{ + device_id, + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; + + arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, + mainloop_arguments, epilogue_arguments, hw_info}; + + // Allocate workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + +// Kernel_TileShape_ClusterShape_Schedule +using Kernel_128x16_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_128x16_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x16_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, 
Coop, CoopEpi>; +using Kernel_256x16_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x32_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x32_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x64_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x64_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x128_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x128_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_128x256_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +void mm_dispatch( + torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides, const std::string& schedule) { + if (schedule == "Kernel_128x16_1x1x1_Coop") { + Kernel_128x16_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_128x16_2x1x1_Coop") { + Kernel_128x16_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x16_1x1x1_Coop") { + Kernel_256x16_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, 
expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x16_2x1x1_Coop") { + Kernel_256x16_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x32_1x1x1_Coop") { + Kernel_256x32_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x32_2x1x1_Coop") { + Kernel_256x32_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x64_1x1x1_Coop") { + Kernel_256x64_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x64_2x1x1_Coop") { + Kernel_256x64_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x128_1x1x1_Coop") { + Kernel_256x128_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x128_2x1x1_Coop") { + Kernel_256x128_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } 
else if (schedule == "Kernel_128x256_2x1x1_Coop") { + Kernel_128x256_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else { + TORCH_CHECK(false, + "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); + } +} + +void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides, + std::optional maybe_schedule) { + // user has specified a schedule + if (maybe_schedule) { + mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + b_group_scales, b_group_size, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides, group_scale_strides, + *maybe_schedule); + return; + } + + // use heuristic + int m_full = a_tensors.size(0); + int n = b_tensors.size(1); + int k = b_tensors.size(2) * PackFactor; // logical k + int num_experts = b_tensors.size(0); + // per-expert batch size assuming uniform distribution + int m_expert = m_full / num_experts; + + std::string schedule; + if (m_expert <= 16) { + schedule = "Kernel_128x16_2x1x1_Coop"; + } else if (m_expert <= 32) { + schedule = "Kernel_256x32_1x1x1_Coop"; + } else if (m_expert <= 64) { + schedule = "Kernel_256x64_1x1x1_Coop"; + } else if (m_expert <= 128) { + schedule = "Kernel_256x128_2x1x1_Coop"; + } else { // m_expert > 128 + schedule = "Kernel_128x256_2x1x1_Coop"; + } + + mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + b_group_scales, b_group_size, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides, group_scale_strides, schedule); 
+} + +std::tuple encode_and_reorder_int4b( + torch::Tensor const& b_tensors) { + TORCH_CHECK(b_tensors.dtype() == torch::kInt32); + TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) + TORCH_CHECK(b_tensors.is_contiguous()); + TORCH_CHECK(b_tensors.is_cuda()); + + int n = static_cast(b_tensors.size(1)); + int k = static_cast(b_tensors.size(2)) * PackFactor; // logical k + + // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0. + // These misalignments cause silent OOB unless run under Compute Sanitizer. + TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); + TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); + + // we will store the layout to an int32 tensor; + // this is the number of elements we need per layout + constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t); + + torch::Tensor b_tensors_packed = torch::empty_like(b_tensors); + int num_experts = static_cast(b_tensors.size(0)); + + auto b_ptr = static_cast(b_tensors.const_data_ptr()); + auto b_packed_ptr = static_cast(b_tensors_packed.data_ptr()); + + // multiply by ull so result does not overflow int32 + size_t num_int4_elems = 1ull * num_experts * n * k; + bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr, + num_int4_elems); + TORCH_CHECK(ok, "unified_encode_int4b failed"); + + // construct the layout once; assumes each expert has the same layout + using LayoutType = LayoutB_Reordered; + std::vector layout_B_reordered_host(num_experts); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}}); + auto shape_B = cute::make_shape(n, k, Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B); + LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + + // reorder weights for each expert + for (int i = 0; i < num_experts; i++) { + // since the storage type of int4b is 1 byte but one element is 4 bits + // we need to adjust the offset + int64_t offset = + 1ull * i * n * k * 
cutlass::sizeof_bits::value / 8; + cutlass::reorder_tensor(b_packed_ptr + offset, layout_B, + layout_B_reordered); + } + + // save the packed layout to torch tensor so we can re-use it + auto cpu_opts = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + torch::Tensor layout_cpu = + torch::empty({num_experts, layout_width}, cpu_opts); + + int32_t* layout_data = layout_cpu.data_ptr(); + for (int i = 0; i < num_experts; ++i) { + std::memcpy(layout_data + i * layout_width, // dst (int32*) + &layout_B_reordered, // src (LayoutType*) + sizeof(LayoutType)); // number of bytes + } + + torch::Tensor packed_layout = + layout_cpu.to(b_tensors.device(), /*non_blocking=*/false); + + return {b_tensors_packed, packed_layout}; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_moe_mm", &mm); + m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8_moe +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu index 2d1568b08651c..f77af06cd6c08 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -7,6 +7,7 @@ #include #include #include "cutlass_extensions/torch_utils.hpp" +#include "w4a8_utils.cuh" #include "core/registration.h" @@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { return packed_scales; } -/* - GPU-accelerated implementation of cutlass::unified_encode_int4b. - Constructs a lookup table in constant memory to map 8 bits - (two 4-bit values) at a time. Assumes memory is contiguous - and pointers are 16-byte aligned. 
-*/ -__constant__ uint8_t kNibbleLUT[256]; - -__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, - size_t nbytes) { - constexpr size_t V = sizeof(uint4); // 16 bytes - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t nthreads = size_t(gridDim.x) * blockDim.x; - const size_t nvec = nbytes / V; - - // 1-D grid-stride loop over 16-byte chunks - for (size_t vec = tid; vec < nvec; vec += nthreads) { - uint4 v = reinterpret_cast(in)[vec]; - uint8_t* b = reinterpret_cast(&v); -#pragma unroll - for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; - reinterpret_cast(out)[vec] = v; - } -} - -static bool upload_lut() { - std::array lut{}; - auto map_nib = [](uint8_t v) -> uint8_t { - // 1..7 -> (8 - v); keep 0 and 8..15 - return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v); - }; - for (int b = 0; b < 256; ++b) { - uint8_t lo = b & 0xF; - uint8_t hi = (b >> 4) & 0xF; - lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); - } - cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), - /*offset=*/0, cudaMemcpyHostToDevice); - - return (e == cudaSuccess); -} - -static bool unified_encode_int4b(cutlass::int4b_t const* in, - cutlass::int4b_t* out, size_t num_int4_elems) { - // Build/upload LUT - if (!upload_lut()) return false; - - static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, - "int4 storage must be 1 byte"); - const size_t nbytes = num_int4_elems >> 1; - - auto* in_bytes = reinterpret_cast(in); - auto* out_bytes = reinterpret_cast(out); - - // kernel launch params - constexpr int block = 256; - const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors - int grid = int((nvec + block - 1) / block); - if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel - - unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); - cudaError_t err = cudaGetLastError(); - return (err == cudaSuccess); -} - torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { 
TORCH_CHECK(B.dtype() == torch::kInt32); TORCH_CHECK(B.dim() == 2); @@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { LayoutB_Reordered layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); - bool ok = - vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr, + n * k); TORCH_CHECK(ok, "unified_encode_int4b failed"); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cu b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu new file mode 100644 index 0000000000000..f238d0a5b2d78 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu @@ -0,0 +1,90 @@ +#include "w4a8_utils.cuh" + +#include +#include +#include + +namespace vllm::cutlass_w4a8_utils { + +/* + GPU-accelerated implementation of cutlass::unified_encode_int4b. + Constructs a lookup table in constant memory to map 8 bits + (two 4-bit values) at a time. Assumes memory is contiguous + and pointers are 16-byte aligned. +*/ +__constant__ uint8_t kNibbleLUT[256]; + +__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, + size_t nbytes) { + constexpr size_t V = sizeof(uint4); // 16 bytes + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t nthreads = size_t(gridDim.x) * blockDim.x; + const size_t nvec = nbytes / V; + + // 1-D grid-stride loop over 16-byte chunks + for (size_t vec = tid; vec < nvec; vec += nthreads) { + uint4 v = reinterpret_cast(in)[vec]; + uint8_t* b = reinterpret_cast(&v); +#pragma unroll + for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; + reinterpret_cast(out)[vec] = v; + } +} + +static bool upload_lut() { + std::array lut{}; + auto map_nib = [](uint8_t v) -> uint8_t { + // 1..7 -> (8 - v); keep 0 and 8..15 + return (v == 0 || (v & 0x8)) ? 
v : uint8_t(8 - v); + }; + for (int b = 0; b < 256; ++b) { + uint8_t lo = b & 0xF; + uint8_t hi = (b >> 4) & 0xF; + lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); + } + cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), + /*offset=*/0, cudaMemcpyHostToDevice); + + return (e == cudaSuccess); +} + +bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out, + size_t num_int4_elems) { + // Build/upload LUT + if (!upload_lut()) return false; + + static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, + "int4 storage must be 1 byte"); + const size_t nbytes = num_int4_elems >> 1; + + auto* in_bytes = reinterpret_cast(in); + auto* out_bytes = reinterpret_cast(out); + + // kernel launch params + constexpr int block = 256; + const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors + int grid = int((nvec + block - 1) / block); + if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel + + unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); + + // launch errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("unified_encode_int4b_device launch error: %s (%d)\n", + cudaGetErrorString(err), err); + return false; + } + + // runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf("unified_encode_int4b_device runtime error: %s (%d)\n", + cudaGetErrorString(err), err); + return false; + } + + return true; +} + +} // namespace vllm::cutlass_w4a8_utils \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh b/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh new file mode 100644 index 0000000000000..25090091a368d --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh @@ -0,0 +1,11 @@ +#pragma once + +#include +#include "cutlass/numeric_types.h" + +namespace vllm::cutlass_w4a8_utils { + +bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out, + size_t num_int4_elems); + +} // namespace 
vllm::cutlass_w4a8_utils \ No newline at end of file diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index 49cafcc32adc6..99fec8fd6febc 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, void get_cutlass_moe_mm_problem_sizes_caller( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets) { + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt) { auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); // Swap-AB should be disabled for FP4 path - bool may_swap_ab = (!blockscale_offsets.has_value()) && - (topk_ids.numel() <= SWAP_AB_THRESHOLD); + bool may_swap_ab = + force_swap_ab.value_or((!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD)); launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, atomic_buffer, num_experts, n, k, stream, diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index c5012a8669317..5de21cfbbaafb 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_problem_sizes_caller( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets); + const int64_t k, const std::optional& 
blockscale_offsets, + std::optional force_swap_ab = std::nullopt); void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, @@ -303,14 +304,15 @@ void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_problem_sizes( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets) { + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt) { int32_t version_num = get_sm_version_num(); #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, - blockscale_offsets); + blockscale_offsets, force_swap_ab); return; #endif TORCH_CHECK_NOT_IMPLEMENTED( diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 62212f98b4766..d4c6f8c67c516 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -350,6 +350,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); // conditionally compiled so impl registration is in source file + // CUTLASS w4a8 grouped GEMM + ops.def( + "cutlass_w4a8_moe_mm(" + " Tensor! out_tensors," + " Tensor a_tensors," + " Tensor b_tensors," + " Tensor a_scales," + " Tensor b_scales," + " Tensor b_group_scales," + " int b_group_size," + " Tensor expert_offsets," + " Tensor problem_sizes," + " Tensor a_strides," + " Tensor b_strides," + " Tensor c_strides," + " Tensor group_scale_strides," + " str? 
maybe_schedule" + ") -> ()"); + ops.def( + "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " + "Tensor)"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -466,7 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! problem_sizes1, " " Tensor! problem_sizes2, " " int num_experts, int n, int k, " - " Tensor? blockscale_offsets) -> ()"); + " Tensor? blockscale_offsets, " + " bool? force_swap_ab) -> ()"); ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, &get_cutlass_moe_mm_problem_sizes); diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py index 465e24fd7eb97..cccef28f5e931 100644 --- a/tests/kernels/quantization/test_cutlass_w4a8.py +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -12,8 +12,11 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import ( + convert_packed_uint4b8_to_signed_int4_inplace, + pack_cols, pack_rows, quantize_weights, + unpack_quantized_values_into_int32, ) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -167,8 +170,7 @@ def create_test_tensors( # for the practical use case we need per-tok scales for fp8 activations w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type) - # weights are already per-group quantized, use placeholder here - w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type) + w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type) return Tensors( w_ref=w_ref, @@ -211,7 +213,7 @@ def mm_test_helper( print(output_ref) torch.testing.assert_close( - output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3 + output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2 ) @@ -257,7 +259,7 @@ def test_w4a8_cuda_graph(): ) w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32) - w_ch_s = 
torch.ones((n,), device="cuda", dtype=torch.float32) + w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32) # Construct a trivial model with a single layer that calls the kernel model = W4A8Layer( @@ -287,4 +289,38 @@ def test_w4a8_cuda_graph(): output.zero_() g.replay() - torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES) +def test_convert_packed_uint4b8_to_signed_int4_inplace(shape): + """ + The W4A16 checkpoints encode the weights as int4b8 packed to int32. + The CUTLASS kernels expect signed int4 packed to int32. + This tests checks that the runtime int4b8 -> signed int4 conversion + matches the offline conversion step exactly. + """ + _, N, K = shape + # random weights packed to int32 + t = torch.randint( + low=torch.iinfo(torch.int32).min, + high=torch.iinfo(torch.int32).max + 1, + size=(N, K // 8), + dtype=torch.int32, + device="cuda", + ) + + # compute reference + unpacked = unpack_quantized_values_into_int32( + t.clone(), scalar_types.uint4b8, packed_dim=1 + ) + unpacked = unpacked - 8 # int4b8 -> signed int4 + ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape) + + out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone()) + + assert torch.equal(ref, out) + assert not torch.equal(ref, t) diff --git a/tests/kernels/quantization/test_cutlass_w4a8_moe.py b/tests/kernels/quantization/test_cutlass_w4a8_moe.py new file mode 100644 index 0000000000000..3560402a29e90 --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8_moe.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer. 
+""" + +import random +from dataclasses import dataclass + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, + quantize_weights, +) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: torch.dtype | None, + group_size: int | None, + zero_points: bool = False, +): + """ + Quantize weights into W4 and compute reference dequantized weights. + + Encoding/reordering of weights and packing of scales is deferred + until after all experts are combined. + """ + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, wtype, group_size=group_size, zero_points=zero_points + ) + + # Since scales are later cast to fp8, recompute w_ref in atype here. + w_ref = ( + w_q.to(torch.float32) + * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0) + ).to(atype) + + # Bit mask prevents sign extension of int4 when packing. + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + # Make weights row-major (N, K). + w_q = w_q.t().contiguous() + + return w_ref, w_q, w_s.to(atype), w_zp + + +def cutlass_preprocess( + w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor] +): + """ + Reorder/encode expert weights and pack scales. + + Returns: + w_q_packed: Packed/encoded int4 weights for all experts. + w_s_packed: Packed fp8 scales for all experts. + packed_layout: Layout/stride metadata for grouped GEMM. 
+ """ + w_s_packed = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts)) + w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped( + torch.stack(w_q_experts) + ) # expects dim 3 + return w_q_packed, w_s_packed, packed_layout + + +GROUP_SIZE = 128 +# (num_experts, N, K) +TEST_SHAPES = [ + (8, 512, 2048), + (8, 2048, 2048), + (64, 512, 1024), + (64, 2048, 2048), + (4, 2048, 768), + (8, 768, 2048), + (64, 1536, 2048), + (128, 8192, 4096), # test overflow int32 +] +ALIGNMENT = 16 # torch._scaled_mm alignment for M, needed for reference check + + +@dataclass +class MoETestSetup: + num_experts: int + K: int + N: int + Ms: list[int] + M_full: int + a: torch.Tensor + a_ref: torch.Tensor + a_strides: torch.Tensor + out: torch.Tensor + c_strides: torch.Tensor + per_tok_scales: torch.Tensor + per_chan_scales: torch.Tensor + w_refs: list[torch.Tensor] + w_q_packed: torch.Tensor + w_s_packed: torch.Tensor + problem_sizes: torch.Tensor + expert_offsets: torch.Tensor + b_strides: torch.Tensor + group_scale_strides: torch.Tensor + + +def make_moe_test_setup( + num_experts: int, + K: int, + N: int, + *, + alignment: int = ALIGNMENT, + max_blocks: int = 64, + device: str = "cuda", + random_zero: bool = False, +) -> MoETestSetup: + """Create a full set of tensors for testing cutlass_w4a8_moe_mm.""" + + assert K % GROUP_SIZE == 0 + # Token counts per expert (multiples of `alignment`). + Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)] + + # set random experts to 0 tokens + if random_zero and num_experts > 1: + num_zero = max(1, num_experts // 8) + zero_indices = random.sample(range(num_experts), k=num_zero) + for idx in zero_indices: + Ms[idx] = 0 + + M_full = sum(Ms) + assert M_full > 0 + + # Activations. + a = to_fp8(torch.randn((M_full, K), device=device)) + a_ref = a.to(torch.float32) + a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device) + + # Output buffer. 
+ out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device) + c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device) + + # Channel/token scales. + per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device) + per_chan_scales = torch.randn( + (num_experts, N, 1), dtype=torch.float32, device=device + ) + + # Expert weights and scales. + wtype = scalar_types.int4 + atype = stype = torch.float8_e4m3fn + w_refs, w_qs, w_ss = [], [], [] + for _ in range(num_experts): + b = to_fp8(torch.randn((K, N), device=device)) + w_ref, w_q, w_s, _ = cutlass_quantize( + atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False + ) + w_refs.append(w_ref) + w_qs.append(w_q) + w_ss.append(w_s) + + w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss) + + problem_sizes = torch.tensor( + [[N, M, K] for M in Ms], dtype=torch.int32, device=device + ) + + expert_offsets = torch.cat( + [ + torch.tensor([0], dtype=torch.int64), + torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1], + ] + ).to(device=device) + + # B strides and group scale strides. 
+ b_strides = packed_layout + group_scale_strides = torch.zeros( + (num_experts, 2), dtype=torch.int64, device=device + ) + group_scale_strides[:, 0] = N + + return MoETestSetup( + num_experts=num_experts, + K=K, + N=N, + Ms=Ms, + M_full=M_full, + a=a, + a_ref=a_ref, + a_strides=a_strides, + out=out, + c_strides=c_strides, + per_tok_scales=per_tok_scales, + per_chan_scales=per_chan_scales, + w_refs=w_refs, + w_q_packed=w_q_packed, + w_s_packed=w_s_packed, + problem_sizes=problem_sizes, + expert_offsets=expert_offsets, + b_strides=b_strides, + group_scale_strides=group_scale_strides, + ) + + +def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor: + """Compute reference output using torch._scaled_mm per expert.""" + out_ref = torch.empty_like(setup.out) + + ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist() + starts = setup.expert_offsets.cpu().tolist() + + for i in range(setup.num_experts): + start, end = starts[i], ends[i] + if start == end: + continue + + out_ref_i = torch._scaled_mm( + setup.a_ref[start:end].to(torch.float8_e4m3fn), + setup.w_refs[i].to(torch.float8_e4m3fn).t().contiguous().t(), + setup.per_tok_scales[start:end], # (M, 1) + setup.per_chan_scales[i].reshape(1, -1), # (1, N) + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + out_ref[start:end] = out_ref_i + + return out_ref + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, + reason="W4A8 Grouped GEMM is not supported on this GPU type.", +) +@pytest.mark.parametrize("shape", TEST_SHAPES) +@pytest.mark.parametrize("random_zero", [True, False]) +def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero): + num_experts, N, K = shape + current_platform.seed_everything(42) + setup = make_moe_test_setup( + num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero + ) + + ops.cutlass_w4a8_moe_mm( + setup.out, + setup.a, + setup.w_q_packed, + setup.per_tok_scales, + setup.per_chan_scales, + setup.w_s_packed, + GROUP_SIZE, + setup.expert_offsets, + 
setup.problem_sizes, + setup.a_strides, + setup.b_strides, + setup.c_strides, + setup.group_scale_strides, + ) + torch.cuda.synchronize() + + out_ref = compute_moe_reference_output(setup) + torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2) + + +class W4A8MoELayer(torch.nn.Module): + """ + Minimal wrapper module to test cuda graphs + """ + + def __init__(self, setup: MoETestSetup): + super().__init__() + self.setup = setup + + def forward(self, a: torch.Tensor) -> torch.Tensor: + s = self.setup + ops.cutlass_w4a8_moe_mm( + s.out, + a, + s.w_q_packed, + s.per_tok_scales, + s.per_chan_scales, + s.w_s_packed, + GROUP_SIZE, + s.expert_offsets, + s.problem_sizes, + s.a_strides, + s.b_strides, + s.c_strides, + s.group_scale_strides, + ) + return s.out + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, + reason="W4A8 Grouped GEMM is not supported on this GPU type.", +) +def test_cutlass_w4a8_moe_mm_cuda_graph(): + current_platform.seed_everything(42) + # Fixed config for CUDA graph test (single parameter point). + num_experts = 8 + K = 512 + N = 2048 + + setup = make_moe_test_setup( + num_experts=num_experts, + K=K, + N=N, + max_blocks=32, + ) + + # Construct model that calls the grouped GEMM kernel. + model = W4A8MoELayer(setup) + + # Build reference output once. + out_ref = compute_moe_reference_output(setup) + + # Capture and run the model in a CUDA graph. 
+ a_static = setup.a.clone() # static input tensor for graph replay + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out_static = model(a_static) + + out_static.zero_() + g.replay() + + torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 56c780ceb1cb5..6bbfe11b6e925 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -695,6 +695,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: return torch.empty_like(b, memory_format=torch.contiguous_format) + @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped") + def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @@ -1058,6 +1062,7 @@ def get_cutlass_moe_mm_problem_sizes( n: int, k: int, blockscale_offsets: torch.Tensor | None = None, + force_swap_ab: bool | None = None, ): """ Compute only the per-expert problem sizes needed by the two grouped matrix @@ -1067,9 +1072,20 @@ def get_cutlass_moe_mm_problem_sizes( - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's multiplication for the two grouped MMs used in the fused MoE operation. + Optional: + - force_swap_ab: If set to True or False, explicitly enable or disable the + A/B input swap optimization. If None (default), the swap + is selected automatically based on tensor sizes. 
""" return torch.ops._C.get_cutlass_moe_mm_problem_sizes( - topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets + topk_ids, + problem_sizes1, + problem_sizes2, + num_experts, + n, + k, + blockscale_offsets, + force_swap_ab, ) @@ -1457,6 +1473,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: return torch.ops._C.cutlass_encode_and_reorder_int4b(b) +def cutlass_w4a8_moe_mm( + out_tensors: torch.Tensor, + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, + a_strides: torch.Tensor, + b_strides: torch.Tensor, + c_strides: torch.Tensor, + group_scale_strides: torch.Tensor, + maybe_schedule: str | None = None, +): + """ + Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the + W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8) + and both per-channel + per-token scaling in the epilogue. + + Args: + out_tensors: + Output buffer for all experts (updated in-place). + a_tensors: + FP8 (E4M3FN) activations for all experts. + b_tensors: + INT4-packed weight matrix for all experts, packed to INT32 + a_scales: + Per-token FP8 activation scales, applied in the epilogue. + b_scales: + Per-channel FP8 weight scales for each expert, applied in the epilogue. + b_group_scales: + FP8 scale values for group-wise INT4 weight blocks. + b_group_size: + Number of elements grouped under each entry of b_group_scales. + expert_offsets: + Cumulative token offsets + problem_sizes: + Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher. + a/b/c/group_scale_strides: + Strides describing the memory layout of the input tensors. + maybe_schedule: + Optional override to choose a specific kernel or epilogue schedule. + + Returns: + out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result. 
+ """ + return torch.ops._C.cutlass_w4a8_moe_mm( + out_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + b_group_scales, + b_group_size, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + c_strides, + group_scale_strides, + maybe_schedule, + ) + + +def cutlass_encode_and_reorder_int4b_grouped( + b_tensors: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9103e84aa7057..1e145a8fcd791 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -63,8 +63,10 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassBatchedExpertsFp8, CutlassExpertsFp8, + CutlassExpertsW4A8Fp8, cutlass_moe_fp4, cutlass_moe_fp8, + cutlass_moe_w4a8_fp8, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( @@ -88,8 +90,10 @@ if HAS_TRITON: "grouped_topk", "cutlass_moe_fp8", "cutlass_moe_fp4", + "cutlass_moe_w4a8_fp8", "CutlassExpertsFp8", "CutlassBatchedExpertsFp8", + "CutlassExpertsW4A8Fp8", "TritonExperts", "BatchedTritonExperts", "DeepGemmExperts", diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index e52845dfa246d..f35cafa0f77dc 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -143,6 +143,7 @@ class FusedMoEQuantDesc: scale: Union[torch.Tensor, "PrecisionConfig", None] = None # Quantization alphas or gscales, used for nvfp4 types. 
+ # W4A8 FP8: used for per-channel scales # TODO(bnell): put some of these in subclasses alpha_or_gscale: torch.Tensor | None = None @@ -442,7 +443,9 @@ class FusedMoEQuantConfig: - a1_scale: Optional scale to be used for a1. - a2_scale: Optional scale to be used for a2. - g1_alphas: Optional global quantization scales for w1 (for nvfp4). + per-channel scales for w1 (for W4A8 FP8). - g2_alphas: Optional global quantization scales for w2 (for nvfp4). + per-channel scales for w2 (for W4A8 FP8). - a1_gscale: Optional global quantization scales for a1 (for nvfp4). - a2_gscale: Optional global quantization scales for a2 (for nvfp4). - w1_bias: Optional biases for w1 (GPT OSS Triton). @@ -461,6 +464,7 @@ class FusedMoEQuantConfig: "mxfp4", "mxfp6_e3m2", "mxfp6_e2m3", + "int4", } if weight_dtype is None: @@ -671,6 +675,31 @@ def int8_w8a16_moe_quant_config( ) +def int4_w4afp8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: list[int] | None = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and int4 weights. 
+ """ + return FusedMoEQuantConfig.make( + torch.float8_e4m3fn, # quant dtype for activations + w1_scale=w1_scale, + w2_scale=w2_scale, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + weight_dtype="int4", # weight dtype for weights + ) + + def biased_moe_quant_config( w1_bias: torch.Tensor | None, w2_bias: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 30144ca5452eb..552e38a71bf98 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1052,3 +1052,404 @@ def run_cutlass_block_scaled_fused_experts( return ( c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) ).sum(dim=1) + + +# W4A8 +def run_cutlass_moe_w4a8_fp8( + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation_callable: Callable, + global_num_experts: int, + expert_map: torch.Tensor | None, + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + w1_chan_scale: torch.Tensor, + w2_chan_scale: torch.Tensor, + a_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides1: torch.Tensor, + b_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + s_strides1: torch.Tensor, + s_strides2: torch.Tensor, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: torch.Tensor | None, + out_dtype: torch.dtype, + per_act_token: bool, + per_out_ch: bool, + use_batched_format: bool, + topk_weights: torch.Tensor | None, + group_size: int, +): + a1q = hidden_states + M = a1q.size(0) + local_E = w1.size(0) + device = a1q.device + _, K, N_packed = w2.shape + N = N_packed * 8 # logical N, pack 8 int4 into 1 int32 + + assert 
per_act_token, "W4A8 must use per-token scales" + assert per_out_ch, "W4A8 must use per-channel scales" + assert w1_scale is not None + assert w2_scale is not None + assert w1_scale.dtype == torch.float8_e4m3fn + assert w2_scale.dtype == torch.float8_e4m3fn + assert w1.dtype == torch.int32 + assert w2.dtype == torch.int32 + assert w1_chan_scale.dtype == torch.float32 + assert w2_chan_scale.dtype == torch.float32 + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert a1q_scale is not None + assert a2_scale is None + assert out_dtype in [torch.bfloat16], f"Invalid output dtype: {out_dtype}" + if expert_map is not None: + assert expert_num_tokens is None + assert not use_batched_format, "batched format not supported yet" + assert group_size == 128, f"Only group size 128 supported but got {group_size=}" + + assert global_num_experts != -1 + assert w1.size(2) * 8 == K, ( + f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}" + ) + + # Translate info from expert_map to topk_ids + if expert_map is not None: + local_topk_ids = torch.where( + expert_map[topk_ids] != -1, expert_map[topk_ids], -1 + ) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.size(1) + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K)) + mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) + act_out = _resize_cache(workspace2, (M * topk, N)) + # original workspace are based on input hidden_states dtype (bf16) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N) + ) + mm2_out = _resize_cache(workspace2, (M * topk, K)) + + problem_sizes1 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + problem_sizes2 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + + num_expert = global_num_experts if expert_map is None else expert_map.size(0) + # permuted a1q reuses workspace2 + a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( + 
a1q, + a1q_scale, + topk_ids, + num_expert, + local_E, + expert_map, + permuted_hidden_states=a1q_perm, + ) + expert_offsets = expert_offsets[:-1] + + # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape) + ops.get_cutlass_moe_mm_problem_sizes( + local_topk_ids, + problem_sizes1, + problem_sizes2, + global_num_experts, + N, + K, + force_swap_ab=True, + ) + + ops.cutlass_w4a8_moe_mm( + mm1_out, + a1q, + w1, + a1q_scale, + w1_chan_scale, + w1_scale, + group_size, + expert_offsets, + problem_sizes1, + a_strides1, + b_strides1, + c_strides1, + s_strides1, + ) + + activation_callable(act_out, mm1_out) + + a2q, a2q_scale = ops.scaled_fp8_quant( + act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out + ) + + if expert_map is not None: + mm2_out.fill_(0) + + ops.cutlass_w4a8_moe_mm( + mm2_out, + a2q, + w2, + a2q_scale, + w2_chan_scale, + w2_scale, + group_size, + expert_offsets, + problem_sizes2, + a_strides2, + b_strides2, + c_strides2, + s_strides2, + ) + + # for non-chunking mode the output is resized from workspace13 + # so we need to make sure mm2_out uses workspace2. 
+ moe_unpermute( + out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm, + ) + + +class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + out_dtype: torch.dtype | None, + a_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides1: torch.Tensor, + b_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + s_strides1: torch.Tensor, + s_strides2: torch.Tensor, + quant_config: FusedMoEQuantConfig, + group_size: int, + ): + super().__init__(quant_config) + self.out_dtype = out_dtype + self.a_strides1 = a_strides1 + self.a_strides2 = a_strides2 + self.b_strides1 = b_strides1 + self.b_strides2 = b_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 + self.s_strides1 = s_strides1 + self.s_strides2 = s_strides2 + self.group_size = group_size + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # topk weights and reduction are fused in moe_unpermute cuda kernel + return TopKWeightAndReduceNoOP() + + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk, max(N // 2, K)) + output = (M, K) + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: 
torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + + expert_num_tokens = None + activation_callable = lambda o, i: self.activation(activation, o, i) + + use_batched_format = ( + self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + ) + assert not use_batched_format, "batched format not supported" + + in_dtype = hidden_states.dtype + + run_cutlass_moe_w4a8_fp8( + output, + hidden_states, + w1, + w2, + topk_ids, + activation_callable, + global_num_experts, + expert_map, + self.w1_scale, + self.w2_scale, + a1q_scale, + a2_scale, + self.g1_alphas, # per-channel scales + self.g2_alphas, # per-channel scales + self.a_strides1, + self.a_strides2, + self.b_strides1, + self.b_strides2, + self.c_strides1, + self.c_strides2, + self.s_strides1, + self.s_strides2, + workspace13, + workspace2, + expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, + self.per_out_ch_quant, + use_batched_format, + topk_weights, + self.group_size, + ) + + +def cutlass_moe_w4a8_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides1: torch.Tensor, + b_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + s_strides1: torch.Tensor, + s_strides2: torch.Tensor, + quant_config: FusedMoEQuantConfig, + activation: str = "silu", + expert_map: torch.Tensor | None = None, + 
apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + group_size: int = 128, +) -> torch.Tensor: + """ + This function computes a w4a8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + mixed-dtype grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, 2*N, K // packed_factor] + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, K, N // packed_factor] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mappings. + - a_strides1 (torch.Tensor): The input strides for the first gemm. + Shape: [num_experts] + - a_strides2 (torch.Tensor): The input strides for the second gemm. + Shape: [num_experts] + - b_strides1 (torch.Tensor): The packed layout for the first gemm weights. + Shape: [num_experts, 3] + dtype: torch.int32 + - b_strides2 (torch.Tensor): The packed layout for the second gemm weights. + Shape: [num_experts, 3] + dtype: torch.int32 + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - s_strides1 (torch.Tensor): strides for the group-wise scales for the first gemm. + Shape: [num_experts, 2] + dtype: torch.int64 + - s_strides2 (torch.Tensor): strides for the group-wise scales for the second gemm. + Shape: [num_experts, 2] + dtype: torch.int64 + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. 
expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. + - global_num_experts (int): The total number of experts. + - group_size (int): The number of weights per scale factor + + Returns: + - torch.Tensor: The bf16 output tensor after applying the MoE layer. + """ + assert quant_config is not None + + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) + + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsW4A8Fp8( + out_dtype=a.dtype, + a_strides1=a_strides1, + a_strides2=a_strides2, + b_strides1=b_strides1, + b_strides2=b_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + s_strides1=s_strides1, + s_strides2=s_strides2, + quant_config=quant_config, + group_size=group_size, + ), + ) + + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 51d3299e7ddf1..075610ec588ae 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -367,7 +367,7 @@ class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPermuteExpertsUnpermute(ABC): """ An abstract base class for the [Permute-Experts-Unpermute] step described - above. + above. 
""" def __init__( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b91ecb59fee18..21f4cfe51d08e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -256,7 +256,7 @@ class CompressedTensorsConfig(QuantizationConfig): if format is not None else is_activation_quantization_format(quant_format) ) - # TODO(czhu): w4a8fp8 is in packed-quantized format + # w4a8fp8 is in packed-quantized format # but needs input activation quantization input_activations = quant_config.get("input_activations") if act_quant_format or input_activations: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8013b29f733bb..619162272c94f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -33,6 +33,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, int4_w4a16_moe_quant_config, + int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, int8_w8a16_moe_quant_config, nvfp4_moe_quant_config, @@ -79,7 +80,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + convert_bf16_scales_to_fp8, + convert_packed_uint4b8_to_signed_int4_inplace, + swizzle_blockscale, +) from 
class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    """MoE method for int4 weights with fp8 activations (W4A8) on SM90,
    backed by the CUTLASS mixed-dtype grouped GEMM.

    Weights are loaded as uint4b8 nibbles packed into int32 and group-wise
    bf16 scales; `process_weights_after_loading` converts them into the
    encoded/reordered layout plus (fp8 group scales, fp32 channel scales)
    expected by the kernel.
    """

    def __init__(
        self,
        weight_quant: QuantizationArgs,
        input_quant: QuantizationArgs,
        moe: FusedMoEConfig,
        layer_name: str | None = None,
    ):
        super().__init__(moe)
        self.weight_quant = weight_quant
        self.input_quant = input_quant

        self.group_size = self.weight_quant.group_size
        self.num_bits = self.weight_quant.num_bits
        # Number of int4 values packed into one int32 storage element (8).
        self.packed_factor = 32 // self.num_bits

        assert self.weight_quant.symmetric, (
            "Only symmetric quantization is supported for W4A8 MoE"
        )
        assert self.weight_quant.actorder != "group"
        assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE"

        self.disable_expert_map = False
        self.layer_name = layer_name

        from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            GroupShape,
        )

        # Dynamic per-token fp8 quantizer; also reused to convert the bf16
        # group scales to fp8 after loading.
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate packed int4 weights, bf16 group scales, and shape
        parameters on `layer` for the checkpoint loader to fill in."""
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # requirement for CUTLASS reorder_tensor
        assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256"
        assert intermediate_size_per_partition % 256 == 0, (
            f"{intermediate_size_per_partition=} must be divisible by 256"
        )
        # storage type, pack 8xint4 into int32
        params_dtype = torch.int32

        # WEIGHTS
        w13_weight_packed = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.packed_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_packed", w13_weight_packed)
        set_weight_attrs(w13_weight_packed, extra_weight_attrs)

        w2_weight_packed = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.packed_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_packed", w2_weight_packed)
        set_weight_attrs(w2_weight_packed, extra_weight_attrs)

        # SCALES
        # weight_scale refers to the group-wise scales
        # they are initially loaded as bf16, we will convert to fp8
        # after loading
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=layer.orig_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=layer.orig_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-GROUP quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # weight shapes (unpacked logical [rows, cols] per expert, filled by
        # the weight loader)
        w2_weight_shape = torch.nn.Parameter(
            torch.empty(num_experts, 2), requires_grad=False
        )
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
        w13_weight_shape = torch.nn.Parameter(
            torch.empty(num_experts, 2), requires_grad=False
        )
        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)

        # don't use input scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer):
        """Convert loaded checkpoint tensors into kernel-ready form:
        build stride tensors, re-encode/reorder the int4 weights, and split
        bf16 group scales into fp8 group scales + fp32 channel scales."""
        device = layer.w13_weight_packed.device

        # STRIDES
        # A, C
        # A-stride of gemm1 and C-stride of gemm2 are both `hidden_size`,
        # so a single tensor is shared for both roles.
        self.a_strides1_c_strides2 = torch.full(
            (layer.local_num_experts,),
            layer.hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.a_strides2 = torch.full(
            (layer.local_num_experts,),
            layer.intermediate_size_per_partition,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides1 = torch.full(
            (layer.local_num_experts,),
            2 * layer.intermediate_size_per_partition,
            device=device,
            dtype=torch.int64,
        )

        # S (group-wise scales)
        # sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it
        self.s_strides1 = torch.zeros(
            (layer.local_num_experts, 2), device=device, dtype=torch.int64
        )
        self.s_strides1[:, 0] = 2 * layer.intermediate_size_per_partition

        self.s_strides2 = torch.zeros(
            (layer.local_num_experts, 2), device=device, dtype=torch.int64
        )
        self.s_strides2[:, 0] = layer.hidden_size

        # encode and reorder weight tensors, and get the layout to pass to
        # the grouped gemm kernel. `b_strides1/2` specifies the entire layout
        # NOTE: the uint4b8 -> signed-int4 conversion is done in-place on the
        # loaded parameter before encoding.
        convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed)
        w13_weight_shuffled, self.b_strides1 = (
            ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed)
        )
        replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled)
        convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed)
        w2_weight_shuffled, self.b_strides2 = (
            ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed)
        )
        replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled)

        # convert bf16 scales to (fp8_scales, channel_scales)
        w13_weight_scale, w13_weight_chan_scale = convert_bf16_scales_to_fp8(
            self.quant_fp8, layer.w13_weight_scale
        )
        w2_weight_scale, w2_weight_chan_scale = convert_bf16_scales_to_fp8(
            self.quant_fp8, layer.w2_weight_scale
        )

        # register channel scales
        layer.register_parameter(
            "w13_weight_chan_scale",
            torch.nn.Parameter(w13_weight_chan_scale, requires_grad=False),
        )
        layer.register_parameter(
            "w2_weight_chan_scale",
            torch.nn.Parameter(w2_weight_chan_scale, requires_grad=False),
        )

        # The scales are stored as (E, N, K // 128) but the kernel expects
        # (E, K // 128, N) in row-major format, so we need to permute the last 2 dims
        # and make it contiguous
        w13_weight_scale_packed = ops.cutlass_pack_scale_fp8(
            w13_weight_scale.permute(0, 2, 1).contiguous()
        )
        replace_parameter(layer, "w13_weight_scale", w13_weight_scale_packed)
        w2_weight_scale_packed = ops.cutlass_pack_scale_fp8(
            w2_weight_scale.permute(0, 2, 1).contiguous()
        )
        replace_parameter(layer, "w2_weight_scale", w2_weight_scale_packed)

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        # No W4A8-specific prepare/finalize; defer to the base class.
        return super().maybe_make_prepare_finalize(routing_tables)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # Store quantization scales; both per-group and per-channel
        # Note we haven't specified the group size here because
        # the quant config logic assumes group-wise scaling
        # and channel-wise scaling are exclusive.
        return int4_w4afp8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,  # group scale
            w2_scale=layer.w2_weight_scale,  # group scale
            g1_alphas=layer.w13_weight_chan_scale,
            g2_alphas=layer.w2_weight_chan_scale,
            per_act_token_quant=True,  # always use dynamic per-token
            per_out_ch_quant=True,  # always use per-channel
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the CUTLASS W4A8 experts implementation; also decides
        whether the expert map must be disabled (multi-dispatcher case)."""
        assert self.moe_quant_config is not None
        assert (
            prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
        ), "BatchedExperts not supported"

        from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8

        experts: FusedMoEPermuteExpertsUnpermute

        logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
        experts = CutlassExpertsW4A8Fp8(
            out_dtype=self.moe.in_dtype,
            a_strides1=self.a_strides1_c_strides2,
            a_strides2=self.a_strides2,
            b_strides1=self.b_strides1,
            b_strides2=self.b_strides2,
            c_strides1=self.c_strides1,
            c_strides2=self.a_strides1_c_strides2,
            s_strides1=self.s_strides1,
            s_strides2=self.s_strides2,
            quant_config=self.moe_quant_config,
            group_size=self.group_size,
        )

        num_dispatchers = prepare_finalize.num_dispatchers()
        self.disable_expert_map = (
            num_dispatchers > 1 or not experts.supports_expert_map()
        )

        return experts

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ):
        """Run routing + the W4A8 CUTLASS MoE forward pass.

        NOTE(review): routing parameters beyond hidden_states/router_logits
        (top_k, renormalize, etc.) are not forwarded to
        `layer.select_experts` here — presumably the layer already holds
        them; confirm against FusedMoE.select_experts.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
            )
        assert self.moe_quant_config is not None
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        # Imported lazily to avoid a circular import with fused_moe.
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_w4a8_fp8,
        )

        return cutlass_moe_w4a8_fp8(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            topk_weights,
            topk_ids,
            quant_config=self.moe_quant_config,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=None if self.disable_expert_map else expert_map,
            a_strides1=self.a_strides1_c_strides2,
            a_strides2=self.a_strides2,
            b_strides1=self.b_strides1,
            b_strides2=self.b_strides2,
            c_strides1=self.c_strides1,
            c_strides2=self.a_strides1_c_strides2,
            s_strides1=self.s_strides1,
            s_strides2=self.s_strides2,
            group_size=self.group_size,
        )

    @property
    def supports_eplb(self) -> bool:
        # Expert-parallel load balancing is not implemented for W4A8.
        return False
here? - # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + # After loading, we will transform bf16 -> fp8 -> + # expand by 8x via `cutlass_pack_scale_fp8` + # and construct per-channel fp32 scales. weight_scale_args = { "weight_loader": weight_loader, "data": torch.empty( output_size_per_partition, scales_and_zp_size, - dtype=torch.float8_e4m3fn, + dtype=params_dtype, ), } @@ -152,17 +153,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader ) - # per-channel scales - weight_chan_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - layer.register_parameter("weight_chan_scale", weight_chan_scale) self.kernel = kernel_type( mp_linear_kernel_config, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py index 8ef6457c952f1..c9c1a3abf7fd3 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -6,7 +6,11 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + convert_bf16_scales_to_fp8, + convert_packed_uint4b8_to_signed_int4_inplace, +) from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -48,7 +52,6 @@ class 
CutlassW4A8LinearKernel(MPLinearKernel): "CUTLASS W4A8, only supported int4", ) - # TODO(czhu): support -1 (column-wise) if c.group_size != 128: return False, "Only group_size 128 is supported" @@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel): # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): - # TODO(czhu): optimize speed/mem usage def transform_w_q(x): assert isinstance(x, BasevLLMParameter) + convert_packed_uint4b8_to_signed_int4_inplace(x.data) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) return x @@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel): x.data = ops.cutlass_pack_scale_fp8(x.data) return x + w_s = getattr(layer, self.w_s_name) + fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data) + w_s.data = fp8_scales + + # register per-channel scales + layer.register_parameter( + "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False) + ) + # Encode/reorder weights and pack scales self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - self._transform_param(layer, "weight_chan_scale", lambda x: x) def apply_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 92ee8c498e01f..d01263f82007d 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" -from collections.abc import Mapping +from collections.abc import Callable, Mapping from dataclasses import dataclass from types import 
MappingProxyType from typing import ClassVar, NamedTuple @@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool: capability_tuple = current_platform.get_device_capability() capability = -1 if capability_tuple is None else capability_tuple.to_int() return cutlass_scaled_mm_supports_fp4(capability) + + +def convert_bf16_scales_to_fp8( + quant_fp8: Callable, scales: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales) + expected by W4A8 GEMM kernels. + """ + assert scales.is_contiguous(), ( + f"scale tensor must be contiguous, got {scales.stride()=}" + ) + assert scales.is_cuda, "scales must be on gpu" + + orig_shape = scales.shape + k_groups = orig_shape[-1] + flat_scales = scales.view(-1, k_groups) + + fp8_scales, chan_scales = quant_fp8(flat_scales) + fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn) + chan_scales *= 8.0 + + # restore original shape + fp8_scales = fp8_scales.view(orig_shape) + chan_scales = chan_scales.view(orig_shape[:-1], -1) + + return fp8_scales, chan_scales + + +def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor: + """ + Convert int4b8 (packed to int32) to signed int4 + """ + assert t.is_cuda, "tensor must be on gpu" + assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}" + + # loop through the 8 4-bit nibbles in each int32 entry + for i in range(8): + shift = 4 * i + # extract the i-th 4-bit nibble + nib = (t >> shift) & 0xF + # clear the original nibble by masking out + t &= ~(0xF << shift) + # convert int4b8 [0..15] to signed int4 [-8..7] by subtracting 8 + # and update in-place + t |= ((nib - 8) & 0xF) << shift + + return t