[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere 2025-12-08 22:29:06 -05:00 committed by GitHub
parent ea657f2078
commit f6227c22ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 2045 additions and 101 deletions

View File

@ -874,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"

View File

@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
void get_cutlass_moe_mm_problem_sizes( void get_cutlass_moe_mm_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets); const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes1,

View File

@ -0,0 +1,104 @@
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
// ElementB is int32 (8 int4 values packed together)
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
//
// Computes, for every expert, the device pointers into the flattened
// A / B / output / scale tensors that the grouped GEMM consumes.
// Launched with a single block of `num_experts` threads (see the
// __CALL_GET_STARTS_KERNEL launch below): thread i fills the slots for
// expert i.
//   - A, output and A-scale pointers advance by the expert's starting token
//     (expert_offsets), like the w8a8 path.
//   - B, B-scale and B-group-scale pointers advance by the expert index,
//     since each expert owns a fixed-size slab: k*n/8 packed words, n
//     channel scales, and scale_k*n group scales.
template <typename ElementA, typename ElementB, typename ElementC,
          typename ElementAccumulator, typename ElementGroupScale>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
    ElementB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
    int64_t scale_k) {
  int expert_id = threadIdx.x;  // one thread per expert
  int64_t expert_offset = expert_offsets[expert_id];

  // same as w8a8: token-offset driven pointers
  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset;
  b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id);

  // w4a8 specific: expert-index driven pointers into the packed weights and
  // the per-group scales
  constexpr int pack_factor = 8;  // pack 8 int4 into int32
  b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor);
  b_group_scales_offsets[expert_id] =
      b_group_scales_base_as_int + (expert_id * scale_k * n);
}
// Expands to one `else if` arm of the output-dtype dispatch in
// run_get_group_gemm_starts: when out_tensors has dtype TENSOR_C_TYPE,
// launches get_group_gemm_starts instantiated with C_TYPE as the output
// element type. Launch shape is one block with one thread per expert, on the
// caller's current stream. (Comments cannot go inside the macro body without
// breaking the line continuations, hence this header comment.)
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                  \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                       \
    get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
                          cutlass::Array<cutlass::float_e4m3_t, 8>>      \
        <<<1, num_experts, 0, stream>>>(                                 \
            static_cast<int64_t*>(expert_offsets.data_ptr()),            \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),     \
            static_cast<int32_t**>(b_ptrs.data_ptr()),                   \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                  \
            static_cast<float**>(a_scales_ptrs.data_ptr()),              \
            static_cast<float**>(b_scales_ptrs.data_ptr()),              \
            static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>(     \
                b_group_scales_ptrs.data_ptr()),                         \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),   \
            static_cast<int32_t*>(b_tensors.data_ptr()),                 \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                \
            static_cast<float*>(a_scales.data_ptr()),                    \
            static_cast<float*>(b_scales.data_ptr()),                    \
            static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(      \
                b_group_scales.data_ptr()),                              \
            n, k, scale_k);                                              \
  }
namespace {
// Populates the per-expert device pointer arrays (a_ptrs, b_ptrs, out_ptrs,
// a/b scale ptrs, b group-scale ptrs) consumed by the grouped GEMM, by
// launching get_group_gemm_starts on the current CUDA stream.
//
// Checked preconditions: fp8 A, int32-packed-int4 B, fp32 channel/token
// scales, fp8 group scales, bf16 output, int64 expert offsets.
void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);  // int4 8x packed into int32
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_group_scales.dtype() ==
              torch::kFloat8_e4m3fn);  // the underlying torch type is e4m3
  TORCH_CHECK(out_tensors.dtype() ==
              torch::kBFloat16);  // only support bf16 for now
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  // logical k, n
  int64_t n = out_tensors.size(1);
  int64_t k = a_tensors.size(1);
  int64_t scale_k = cutlass::ceil_div(k, b_group_size);

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Output-dtype dispatch. The empty `if (false)` lets each macro expansion
  // start with `else if`.
  // NOTE(review): the bf16 TORCH_CHECK above makes the kFloat16 arm below
  // unreachable at present; it is kept for when fp16 output is enabled.
  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
}  // namespace

View File

@ -0,0 +1,483 @@
#include <vector>
#include <tuple>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
namespace vllm::cutlass_w4a8_moe {

using namespace cute;

// -------------------------------------------------------------------------------------
// Static configuration shared across all instantiations
// -------------------------------------------------------------------------------------

// Grouped problem shape: one <M,N,K> per expert/group.
using ProblemShape =
    cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

using MmaType = cutlass::float_e4m3_t;  // fp8 compute type (activations)
using QuantType = cutlass::int4b_t;     // int4 quantized type (weights)
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
static int constexpr PackFactor = 8;  // 8 int4 packed into int32

// A matrix configuration
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
constexpr int AlignmentA =
    128 /
    cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of
                                            // elements (up to 16 bytes)

// B matrix configuration
using ElementB = QuantType;  // Element type for B matrix operand
using LayoutB =
    cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
constexpr int AlignmentB =
    128 / cutlass::sizeof_bits<
              ElementB>::value;  // Memory access granularity/alignment of B
                                 // matrix in units of elements (up to 16 bytes)

// The kernel manually swaps and transposes the operands (SwapAB), so keep
// transposes of the input layouts around.
using LayoutA_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutB>::type;

// Need to pass a pointer type to make the 3rd dimension of Stride be _0
using StrideA =
    cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB =
    cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;

// CuTe layout for the reordered quantized tensor B.
// LayoutAtomQuant places values that will be read by the same thread in
// contiguous locations in global memory. It specifies the reordering within a
// single warp's fragment.
using LayoutAtomQuant =
    decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
    LayoutAtomQuant{}, Layout<Shape<int, int, Int<1>>, StrideB>{}));

// Group-scale configuration (packed fp8, row-major)
using ElementScale = cutlass::float_e4m3_t;
using LayoutScale = cutlass::layout::RowMajor;

// C/D matrix configuration
using ElementC =
    cutlass::bfloat16_t;  // Element type for C and D matrix operands
using LayoutC =
    cutlass::layout::RowMajor;  // Layout type for C and D matrix operands
constexpr int AlignmentC =
    128 / cutlass::sizeof_bits<
              ElementC>::value;  // Memory access granularity/alignment of C
                                 // matrix in units of elements (up to 16 bytes)

// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float;     // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that
                                      // supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
using StageCountType =
    cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based
                                                // on the tile size

// per-channel and per-token scales for epilogue
using ElementSChannel = float;
// Builds one concrete CUTLASS 3.x grouped-GEMM instantiation for the given
// CTA tile shape (M,N), cluster shape, and mainloop/epilogue schedules.
//
// The kernel uses SwapAB: the quantized weight operand B (paired with its
// packed fp8 group scales) is fed to the mainloop first, and the C/D layouts
// are transposed to match. The epilogue applies per-channel (weight) and
// per-token (activation) float scales.
template <class TileShape_MN, class ClusterShape_MNK, class KernelSchedule,
          class EpilogueSchedule>
struct W4A8GroupedGemmKernel {
  // Full CTA tile: (M, N) from the template parameter, K fixed by TileShapeK.
  using TileShape =
      decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
  using ClusterShape = ClusterShape_MNK;

  // per-channel, per-token scales epilogue
  using ChTokScalesEpilogue =
      typename vllm::c3x::ScaledEpilogueArray<ElementAccumulator, ElementD,
                                              TileShape>;
  using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementSChannel, ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type*, AlignmentC,
          ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*,
          AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp;

  // =========================================================== MIXED INPUT
  // WITH SCALES
  // ===========================================================================
  // The scale information must get paired with the operand that will be
  // scaled. Here B is scaled, so we make a tuple of B's element type and the
  // packed group-scale type.
  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*,
          AlignmentA, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>;

  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::InternalStrideC;
  using StrideD = typename GemmKernelShuffled::InternalStrideD;
  using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
  using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;
  using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;

  // static asserts for passing in strides/layouts as raw tensor bytes
  // pack to 2x int64
  static_assert(sizeof(StrideS) == 2 * sizeof(int64_t));
  // pack to 3x int32
  static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
                "LayoutB_Reordered size must be divisible by 4 bytes");

  // Runs the grouped GEMM over all experts on the current CUDA stream.
  // The *_strides tensors are expected to contain the raw bytes of the
  // corresponding CUTLASS stride/layout objects, one per expert;
  // problem_sizes_torch holds the per-expert <M,N,K> shapes.
  static void grouped_mm(
      torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
      const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
      const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
      const int64_t b_group_size, const torch::Tensor& expert_offsets,
      const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
      const torch::Tensor& b_strides, const torch::Tensor& c_strides,
      const torch::Tensor& group_scale_strides) {
    auto device = a_tensors.device();
    auto device_id = device.index();
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto stream = at::cuda::getCurrentCUDAStream(device_id);

    int num_experts = static_cast<int>(expert_offsets.size(0));
    int n = static_cast<int>(b_tensors.size(1));
    int k = static_cast<int>(b_tensors.size(2)) * PackFactor;  // logical k

    // One int64 slot (a raw device pointer) per expert, per operand.
    auto options_int =
        torch::TensorOptions().dtype(torch::kInt64).device(device);
    torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);

    // get the correct offsets to pass to gemm
    run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
                              a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs,
                              a_tensors, b_tensors, out_tensors, a_scales,
                              b_scales, b_group_scales, b_group_size);

    // construct args
    using Args = typename GemmShuffled::Arguments;
    using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
    using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
    Args arguments;

    // Per-expert problem shapes; the host-side shape array is not provided
    // (nullptr).
    ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
        static_cast<ProblemShape::UnderlyingProblemShape*>(
            problem_sizes_torch.data_ptr());
    ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};

    // SwapAB so B operands come first
    MainloopArguments mainloop_arguments{
        static_cast<const QuantType**>(b_ptrs.data_ptr()),
        static_cast<LayoutB_Reordered*>(b_strides.data_ptr()),
        static_cast<const MmaType**>(a_ptrs.data_ptr()),
        static_cast<StrideA*>(a_strides.data_ptr()),
        static_cast<const cutlass::Array<ElementScale, 8>**>(
            b_group_scales_ptrs.data_ptr()),
        static_cast<StrideS*>(group_scale_strides.data_ptr()),
        static_cast<int>(b_group_size)};

    EpilogueArguments epilogue_arguments{
        // since we are doing SwapAB the channel scales come first, then token
        // scales
        ChTokScalesEpilogue::prepare_args(  // see ScaledEpilogueArray
            static_cast<const ElementAccumulator**>(
                b_scales_ptrs.data_ptr()),  // per-channel
            static_cast<const ElementAccumulator**>(
                a_scales_ptrs.data_ptr()),  // per-token
            true, true),
        nullptr,                                       // C (unused: nullptr)
        static_cast<StrideC*>(c_strides.data_ptr()),   // C
        static_cast<ElementD**>(out_ptrs.data_ptr()),  // D
        static_cast<StrideC*>(c_strides.data_ptr())    // D
    };

    static const cutlass::KernelHardwareInfo hw_info{
        device_id,
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
            device_id)};

    arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape,
                     mainloop_arguments, epilogue_arguments, hw_info};

    // Allocate workspace
    size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
    torch::Tensor workspace =
        torch::empty(workspace_size,
                     torch::TensorOptions().dtype(torch::kU8).device(device));

    // Run GEMM
    GemmShuffled gemm;
    CUTLASS_CHECK(gemm.can_implement(arguments));
    CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
    CUTLASS_CHECK(gemm.run(stream));
  }
};
// ----------------------------------------------------------------------------
// Kernel instantiations and dispatch logic
// ----------------------------------------------------------------------------
using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;

// Naming: Kernel_<TileM>x<TileN>_<ClusterShape>_<Schedule>
using Kernel_128x16_1x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>, Coop, CoopEpi>;
using Kernel_128x16_2x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_128, _16>, Shape<_2, _1, _1>, Coop, CoopEpi>;
using Kernel_256x16_1x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>, Coop, CoopEpi>;
using Kernel_256x16_2x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _16>, Shape<_2, _1, _1>, Coop, CoopEpi>;
using Kernel_256x32_1x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>, Coop, CoopEpi>;
using Kernel_256x32_2x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _32>, Shape<_2, _1, _1>, Coop, CoopEpi>;
using Kernel_256x64_1x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>, Coop, CoopEpi>;
using Kernel_256x64_2x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _64>, Shape<_2, _1, _1>, Coop, CoopEpi>;
using Kernel_256x128_1x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>, Coop, CoopEpi>;
using Kernel_256x128_2x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_256, _128>, Shape<_2, _1, _1>, Coop, CoopEpi>;
using Kernel_128x256_2x1x1_Coop =
    W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
// Dispatches a grouped W4A8 GEMM to the kernel instantiation named by
// `schedule` (one of the Kernel_* aliases above). Raises via TORCH_CHECK if
// the schedule string does not match any known instantiation.
void mm_dispatch(
    torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
    const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
    const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
    const int64_t b_group_size, const torch::Tensor& expert_offsets,
    const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
    const torch::Tensor& b_strides, const torch::Tensor& c_strides,
    const torch::Tensor& group_scale_strides, const std::string& schedule) {
  // Pointer type matching W4A8GroupedGemmKernel<...>::grouped_mm.
  using DispatchFn = void (*)(
      torch::Tensor&, const torch::Tensor&, const torch::Tensor&,
      const torch::Tensor&, const torch::Tensor&, const torch::Tensor&,
      const int64_t, const torch::Tensor&, const torch::Tensor&,
      const torch::Tensor&, const torch::Tensor&, const torch::Tensor&,
      const torch::Tensor&);
  struct ScheduleEntry {
    const char* name;
    DispatchFn fn;
  };
  // Table-driven dispatch: schedule string -> kernel instantiation.
  static constexpr ScheduleEntry kScheduleTable[] = {
      {"Kernel_128x16_1x1x1_Coop", &Kernel_128x16_1x1x1_Coop::grouped_mm},
      {"Kernel_128x16_2x1x1_Coop", &Kernel_128x16_2x1x1_Coop::grouped_mm},
      {"Kernel_256x16_1x1x1_Coop", &Kernel_256x16_1x1x1_Coop::grouped_mm},
      {"Kernel_256x16_2x1x1_Coop", &Kernel_256x16_2x1x1_Coop::grouped_mm},
      {"Kernel_256x32_1x1x1_Coop", &Kernel_256x32_1x1x1_Coop::grouped_mm},
      {"Kernel_256x32_2x1x1_Coop", &Kernel_256x32_2x1x1_Coop::grouped_mm},
      {"Kernel_256x64_1x1x1_Coop", &Kernel_256x64_1x1x1_Coop::grouped_mm},
      {"Kernel_256x64_2x1x1_Coop", &Kernel_256x64_2x1x1_Coop::grouped_mm},
      {"Kernel_256x128_1x1x1_Coop", &Kernel_256x128_1x1x1_Coop::grouped_mm},
      {"Kernel_256x128_2x1x1_Coop", &Kernel_256x128_2x1x1_Coop::grouped_mm},
      {"Kernel_128x256_2x1x1_Coop", &Kernel_128x256_2x1x1_Coop::grouped_mm},
  };

  for (const auto& entry : kScheduleTable) {
    if (schedule == entry.name) {
      entry.fn(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
               b_group_scales, b_group_size, expert_offsets, problem_sizes,
               a_strides, b_strides, c_strides, group_scale_strides);
      return;
    }
  }
  TORCH_CHECK(false,
              "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
// Entry point for the W4A8 grouped (MoE) GEMM.
//
// If `maybe_schedule` is provided it is used verbatim; otherwise a schedule
// is selected heuristically from the per-expert batch size, assuming tokens
// are distributed uniformly across experts. Only M drives the heuristic.
//
// Fixes vs. original: removed the unused `n`/`k` locals (they were computed
// but never consulted by the heuristic) and added a guard against a
// zero-expert weight tensor, which would otherwise divide by zero.
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
        const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
        const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
        const int64_t b_group_size, const torch::Tensor& expert_offsets,
        const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
        const torch::Tensor& b_strides, const torch::Tensor& c_strides,
        const torch::Tensor& group_scale_strides,
        std::optional<std::string> maybe_schedule) {
  // user has specified a schedule
  if (maybe_schedule) {
    mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                b_group_scales, b_group_size, expert_offsets, problem_sizes,
                a_strides, b_strides, c_strides, group_scale_strides,
                *maybe_schedule);
    return;
  }

  // use heuristic: per-expert batch size assuming uniform distribution
  int num_experts = static_cast<int>(b_tensors.size(0));
  TORCH_CHECK(num_experts > 0,
              "cutlass_w4a8_moe_mm: b_tensors must have at least one expert");
  int m_full = static_cast<int>(a_tensors.size(0));
  int m_expert = m_full / num_experts;

  std::string schedule;
  if (m_expert <= 16) {
    schedule = "Kernel_128x16_2x1x1_Coop";
  } else if (m_expert <= 32) {
    schedule = "Kernel_256x32_1x1x1_Coop";
  } else if (m_expert <= 64) {
    schedule = "Kernel_256x64_1x1x1_Coop";
  } else if (m_expert <= 128) {
    schedule = "Kernel_256x128_2x1x1_Coop";
  } else {  // m_expert > 128
    schedule = "Kernel_128x256_2x1x1_Coop";
  }
  mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
              b_group_scales, b_group_size, expert_offsets, problem_sizes,
              a_strides, b_strides, c_strides, group_scale_strides, schedule);
}
// Re-encodes the packed int4 weights of every expert into CUTLASS's unified
// int4 encoding and reorders each expert's (n, k) slab into the shuffled
// LayoutB_Reordered consumed by the mixed-input mainloop.
//
// Input:  b_tensors — (num_experts, n, k/8) int32, contiguous, CUDA.
// Returns {encoded+reordered weights (same shape/dtype as input),
//          per-expert layout blob (num_experts, layout_width) int32 on the
//          input's device} — the layout bytes are fed back to grouped_mm as
//          b_strides.
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
    torch::Tensor const& b_tensors) {
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
  TORCH_CHECK(b_tensors.dim() == 3);  // (experts, n, k)
  TORCH_CHECK(b_tensors.is_contiguous());
  TORCH_CHECK(b_tensors.is_cuda());
  int n = static_cast<int>(b_tensors.size(1));
  int k = static_cast<int>(b_tensors.size(2)) * PackFactor;  // logical k

  // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
  // These misalignments cause silent OOB unless run under Compute Sanitizer.
  TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
  TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");

  // we will store the layout to an int32 tensor;
  // this is the number of elements we need per layout
  constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);

  torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
  int num_experts = static_cast<int>(b_tensors.size(0));
  auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
  auto b_packed_ptr = static_cast<QuantType*>(b_tensors_packed.data_ptr());

  // multiply by ull so result does not overflow int32
  size_t num_int4_elems = 1ull * num_experts * n * k;
  bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
                                                           num_int4_elems);
  TORCH_CHECK(ok, "unified_encode_int4b failed");

  // construct the layout once; assumes each expert has the same layout
  using LayoutType = LayoutB_Reordered;
  std::vector<LayoutType> layout_B_reordered_host(num_experts);
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}});
  auto shape_B = cute::make_shape(n, k, Int<1>{});
  auto layout_B = make_layout(shape_B, stride_B);
  LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);

  // reorder weights for each expert
  for (int i = 0; i < num_experts; i++) {
    // since the storage type of int4b is 1 byte but one element is 4 bits
    // we need to adjust the offset
    int64_t offset =
        1ull * i * n * k * cutlass::sizeof_bits<QuantType>::value / 8;
    cutlass::reorder_tensor(b_packed_ptr + offset, layout_B,
                            layout_B_reordered);
  }

  // save the packed layout to torch tensor so we can re-use it;
  // every expert row receives the same layout bytes (shapes are identical)
  auto cpu_opts =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
  torch::Tensor layout_cpu =
      torch::empty({num_experts, layout_width}, cpu_opts);
  int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
  for (int i = 0; i < num_experts; ++i) {
    std::memcpy(layout_data + i * layout_width,  // dst (int32*)
                &layout_B_reordered,             // src (LayoutType*)
                sizeof(LayoutType));             // number of bytes
  }
  torch::Tensor packed_layout =
      layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
  return {b_tensors_packed, packed_layout};
}
// Register the CUDA implementations of the W4A8 MoE custom ops with torch.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("cutlass_w4a8_moe_mm", &mm);
  m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
}
} // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -7,6 +7,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <torch/all.h> #include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h" #include "core/registration.h"
@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales; return packed_scales;
} }
/*
GPU-accelerated implementation of cutlass::unified_encode_int4b.
Constructs a lookup table in constant memory to map 8 bits
(two 4-bit values) at a time. Assumes memory is contiguous
and pointers are 16-byte aligned.
*/
__constant__ uint8_t kNibbleLUT[256];
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
size_t nbytes) {
constexpr size_t V = sizeof(uint4); // 16 bytes
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
const size_t nvec = nbytes / V;
// 1-D grid-stride loop over 16-byte chunks
for (size_t vec = tid; vec < nvec; vec += nthreads) {
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
reinterpret_cast<uint4*>(out)[vec] = v;
}
}
static bool upload_lut() {
std::array<uint8_t, 256> lut{};
auto map_nib = [](uint8_t v) -> uint8_t {
// 1..7 -> (8 - v); keep 0 and 8..15
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
};
for (int b = 0; b < 256; ++b) {
uint8_t lo = b & 0xF;
uint8_t hi = (b >> 4) & 0xF;
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
}
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
/*offset=*/0, cudaMemcpyHostToDevice);
return (e == cudaSuccess);
}
static bool unified_encode_int4b(cutlass::int4b_t const* in,
cutlass::int4b_t* out, size_t num_int4_elems) {
// Build/upload LUT
if (!upload_lut()) return false;
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
"int4 storage must be 1 byte");
const size_t nbytes = num_int4_elems >> 1;
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
// kernel launch params
constexpr int block = 256;
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
int grid = int((nvec + block - 1) / block);
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
cudaError_t err = cudaGetLastError();
return (err == cudaSuccess);
}
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32); TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2); TORCH_CHECK(B.dim() == 2);
@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
LayoutB_Reordered layout_B_reordered = LayoutB_Reordered layout_B_reordered =
cute::tile_to_shape(LayoutAtomQuant{}, shape_B); cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
bool ok = bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed"); TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);

View File

@ -0,0 +1,90 @@
#include "w4a8_utils.cuh"
#include <array>
#include <cuda_runtime.h>
#include <cstdio>
namespace vllm::cutlass_w4a8_utils {
/*
  GPU-accelerated implementation of cutlass::unified_encode_int4b.
  Constructs a lookup table in constant memory to map 8 bits
  (two 4-bit values) at a time. Assumes memory is contiguous
  and pointers are 16-byte aligned.
*/
// Byte-wise lookup table (covers both packed nibbles at once); uploaded from
// the host via cudaMemcpyToSymbol before the kernel is launched.
__constant__ uint8_t kNibbleLUT[256];

// Remaps every byte of `in` into `out` through kNibbleLUT.
// Bulk of the data is processed as 16-byte (uint4) vectors via a grid-stride
// loop; the remaining nbytes % 16 tail bytes are handled one at a time.
// (The original version skipped the tail entirely, silently leaving the last
// bytes unencoded whenever nbytes was not a multiple of 16.)
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
                                            size_t nbytes) {
  constexpr size_t V = sizeof(uint4);  // 16 bytes
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t nthreads = size_t(gridDim.x) * blockDim.x;
  const size_t nvec = nbytes / V;

  // 1-D grid-stride loop over full 16-byte chunks
  for (size_t vec = tid; vec < nvec; vec += nthreads) {
    uint4 v = reinterpret_cast<const uint4*>(in)[vec];
    uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
    for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
    reinterpret_cast<uint4*>(out)[vec] = v;
  }

  // Tail: bytes past the last full 16-byte vector.
  for (size_t i = nvec * V + tid; i < nbytes; i += nthreads) {
    out[i] = kNibbleLUT[in[i]];
  }
}
// Builds the 256-entry byte lookup table on the host and copies it into the
// kNibbleLUT constant-memory symbol. Nibble mapping: 1..7 -> (8 - v); 0 and
// 8..15 are kept as-is. Returns false if the copy to the device fails.
static bool upload_lut() {
  // Precompute the per-nibble remap once.
  uint8_t nib_map[16];
  for (int v = 0; v < 16; ++v) {
    nib_map[v] = (v == 0 || (v & 0x8)) ? uint8_t(v) : uint8_t(8 - v);
  }
  // Expand to a byte-wise table so the kernel remaps two nibbles per lookup.
  std::array<uint8_t, 256> table{};
  for (int byte = 0; byte < 256; ++byte) {
    const uint8_t lo = nib_map[byte & 0xF];
    const uint8_t hi = nib_map[(byte >> 4) & 0xF];
    table[byte] = uint8_t((hi << 4) | lo);
  }
  cudaError_t status =
      cudaMemcpyToSymbol(kNibbleLUT, table.data(), table.size(),
                         /*offset=*/0, cudaMemcpyHostToDevice);
  return (status == cudaSuccess);
}
// Applies the unified int4 encoding to `num_int4_elems` packed int4 values.
// `in` and `out` must be contiguous, 16-byte-aligned device pointers.
// Returns false if the LUT upload, the kernel launch, or the kernel execution
// fails. Synchronizes the device to surface runtime errors — acceptable here
// since this runs once at weight-load time, not in the inference hot path.
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
                          size_t num_int4_elems) {
  // Build/upload LUT
  if (!upload_lut()) return false;

  static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
                "int4 storage must be 1 byte");
  const size_t nbytes = num_int4_elems >> 1;  // two int4 values per byte
  auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
  auto* out_bytes = reinterpret_cast<uint8_t*>(out);

  // kernel launch params
  constexpr int block = 256;
  const size_t nvec = nbytes / sizeof(uint4);  // # of 16B vectors
  int grid = int((nvec + block - 1) / block);
  if (grid == 0) grid = 1;  // never issue a zero-block launch for tiny inputs

  unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);

  // launch errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    // NOTE(review): errors go to stdout via printf; stderr may be preferable
    printf("unified_encode_int4b_device launch error: %s (%d)\n",
           cudaGetErrorString(err), err);
    return false;
  }
  // runtime errors
  err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    printf("unified_encode_int4b_device runtime error: %s (%d)\n",
           cudaGetErrorString(err), err);
    return false;
  }
  return true;
}
} // namespace vllm::cutlass_w4a8_utils

View File

@ -0,0 +1,11 @@
#pragma once
#include <cstddef>
#include "cutlass/numeric_types.h"

namespace vllm::cutlass_w4a8_utils {

// GPU-accelerated equivalent of cutlass::unified_encode_int4b: re-encodes
// `num_int4_elems` packed int4 values from `in` to `out`. Both pointers must
// reference contiguous, 16-byte-aligned device memory. Returns false on any
// CUDA error (see w4a8_utils.cu for details).
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
                          size_t num_int4_elems);

}  // namespace vllm::cutlass_w4a8_utils

View File

@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
void get_cutlass_moe_mm_problem_sizes_caller( void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
// Swap-AB should be disabled for FP4 path // Swap-AB should be disabled for FP4 path
bool may_swap_ab = (!blockscale_offsets.has_value()) && bool may_swap_ab =
(topk_ids.numel() <= SWAP_AB_THRESHOLD); force_swap_ab.value_or((!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD));
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream, atomic_buffer, num_experts, n, k, stream,

View File

@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller(
void get_cutlass_moe_mm_problem_sizes_caller( void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets); const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes1,
@ -303,14 +304,15 @@ void get_cutlass_moe_mm_data(
void get_cutlass_moe_mm_problem_sizes( void get_cutlass_moe_mm_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k, problem_sizes2, num_experts, n, k,
blockscale_offsets); blockscale_offsets, force_swap_ab);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(

View File

@ -350,6 +350,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file // conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
// conditionally compiled so impl registration is in source file
#endif #endif
// Dequantization for GGML. // Dequantization for GGML.
@ -466,7 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, " " Tensor! problem_sizes1, "
" Tensor! problem_sizes2, " " Tensor! problem_sizes2, "
" int num_experts, int n, int k, " " int num_experts, int n, int k, "
" Tensor? blockscale_offsets) -> ()"); " Tensor? blockscale_offsets, "
" bool? force_swap_ab) -> ()");
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes); &get_cutlass_moe_mm_problem_sizes);

View File

@ -12,8 +12,11 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_packed_uint4b8_to_signed_int4_inplace,
pack_cols,
pack_rows, pack_rows,
quantize_weights, quantize_weights,
unpack_quantized_values_into_int32,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
@ -167,8 +170,7 @@ def create_test_tensors(
# for the practical use case we need per-tok scales for fp8 activations # for the practical use case we need per-tok scales for fp8 activations
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type) w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
# weights are already per-group quantized, use placeholder here w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type)
w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)
return Tensors( return Tensors(
w_ref=w_ref, w_ref=w_ref,
@ -211,7 +213,7 @@ def mm_test_helper(
print(output_ref) print(output_ref)
torch.testing.assert_close( torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3 output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2
) )
@ -257,7 +259,7 @@ def test_w4a8_cuda_graph():
) )
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32) w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32) w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32)
# Construct a trivial model with a single layer that calls the kernel # Construct a trivial model with a single layer that calls the kernel
model = W4A8Layer( model = W4A8Layer(
@ -287,4 +289,38 @@ def test_w4a8_cuda_graph():
output.zero_() output.zero_()
g.replay() g.replay()
torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES)
def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
    """
    The W4A16 checkpoints encode the weights as int4b8 packed to int32.
    The CUTLASS kernels expect signed int4 packed to int32.
    This test checks that the runtime int4b8 -> signed int4 conversion
    matches the offline conversion step exactly.
    """
    _, N, K = shape
    # random weights packed to int32 (8 nibbles per int32, hence K // 8 cols)
    t = torch.randint(
        low=torch.iinfo(torch.int32).min,
        high=torch.iinfo(torch.int32).max + 1,
        size=(N, K // 8),
        dtype=torch.int32,
        device="cuda",
    )
    # compute reference: unpack, shift by the uint4b8 bias, repack
    unpacked = unpack_quantized_values_into_int32(
        t.clone(), scalar_types.uint4b8, packed_dim=1
    )
    unpacked = unpacked - 8  # int4b8 -> signed int4
    ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape)
    out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone())
    # Conversion must match the reference and must actually change the data.
    assert torch.equal(ref, out)
    assert not torch.equal(ref, t)

View File

@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
"""
import random
from dataclasses import dataclass
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
def cutlass_quantize(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """
    Quantize weights into W4 and compute reference dequantized weights.
    Encoding/reordering of weights and packing of scales is deferred
    until after all experts are combined.

    Returns (w_ref, w_q, w_s, w_zp): reference dequantized weights in
    `atype`, row-packed int4 weights in (N, K // 8) layout, scales cast to
    `atype`, and zero points from quantize_weights.
    """
    # NOTE(review): `stype` is accepted but not used in this body — confirm
    # whether it is intentional (kept for signature parity with other helpers).
    assert wtype.is_integer(), "TODO: support floating point weights"
    w_ref, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )
    # Since scales are later cast to fp8, recompute w_ref in atype here.
    w_ref = (
        w_q.to(torch.float32)
        * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
    ).to(atype)
    # Bit mask prevents sign extension of int4 when packing.
    w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
    # Make weights row-major (N, K).
    w_q = w_q.t().contiguous()
    return w_ref, w_q, w_s.to(atype), w_zp
def cutlass_preprocess(
    w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
):
    """
    Reorder/encode the per-expert weights and pack the group scales.

    Returns:
        packed_weights: Packed/encoded int4 weights for all experts.
        packed_scales: Packed fp8 scales for all experts.
        layout: Layout/stride metadata for the grouped GEMM.
    """
    stacked_scales = torch.stack(w_s_experts)
    packed_scales = ops.cutlass_pack_scale_fp8(stacked_scales)
    stacked_weights = torch.stack(w_q_experts)  # encoder expects a 3-D tensor
    packed_weights, layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        stacked_weights
    )
    return packed_weights, packed_scales, layout
# Group size for the group-wise int4 weight scales.
GROUP_SIZE = 128
# Test shapes: (num_experts, N, K)
TEST_SHAPES = [
    (8, 512, 2048),
    (8, 2048, 2048),
    (64, 512, 1024),
    (64, 2048, 2048),
    (4, 2048, 768),
    (8, 768, 2048),
    (64, 1536, 2048),
    (128, 8192, 4096),  # test overflow int32
]
ALIGNMENT = 16  # torch._scaled_mm alignment for M, needed for reference check
@dataclass
class MoETestSetup:
    """All tensors and metadata needed to call cutlass_w4a8_moe_mm once."""

    # Problem dimensions.
    num_experts: int
    K: int
    N: int
    # Per-expert token counts and their total.
    Ms: list[int]
    M_full: int
    # FP8 activations, their fp32 reference copy, and A-operand strides.
    a: torch.Tensor
    a_ref: torch.Tensor
    a_strides: torch.Tensor
    # Output buffer and C-operand strides.
    out: torch.Tensor
    c_strides: torch.Tensor
    # Epilogue scales: per-token (activations) and per-channel (weights).
    per_tok_scales: torch.Tensor
    per_chan_scales: torch.Tensor
    # Reference dequantized weights plus packed int4 weights / fp8 scales.
    w_refs: list[torch.Tensor]
    w_q_packed: torch.Tensor
    w_s_packed: torch.Tensor
    # Grouped-GEMM launch metadata.
    problem_sizes: torch.Tensor
    expert_offsets: torch.Tensor
    b_strides: torch.Tensor
    group_scale_strides: torch.Tensor
def make_moe_test_setup(
    num_experts: int,
    K: int,
    N: int,
    *,
    alignment: int = ALIGNMENT,
    max_blocks: int = 64,
    device: str = "cuda",
    random_zero: bool = False,
) -> MoETestSetup:
    """Create a full set of tensors for testing cutlass_w4a8_moe_mm.

    Args:
        num_experts: number of experts.
        K: reduction dim; must be a multiple of GROUP_SIZE.
        N: output dim.
        alignment: per-expert token counts are multiples of this.
        max_blocks: upper bound (in `alignment` units) on tokens per expert.
        device: device for all allocated tensors.
        random_zero: if True, assign zero tokens to ~1/8 of the experts.
    """
    assert K % GROUP_SIZE == 0
    # Token counts per expert (multiples of `alignment`).
    Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)]
    # set random experts to 0 tokens
    if random_zero and num_experts > 1:
        num_zero = max(1, num_experts // 8)
        zero_indices = random.sample(range(num_experts), k=num_zero)
        for idx in zero_indices:
            Ms[idx] = 0
    M_full = sum(Ms)
    assert M_full > 0
    # Activations.
    a = to_fp8(torch.randn((M_full, K), device=device))
    a_ref = a.to(torch.float32)
    a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device)
    # Output buffer.
    out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
    c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)
    # Channel/token scales.
    per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
    per_chan_scales = torch.randn(
        (num_experts, N, 1), dtype=torch.float32, device=device
    )
    # Expert weights and scales.
    wtype = scalar_types.int4
    atype = stype = torch.float8_e4m3fn
    w_refs, w_qs, w_ss = [], [], []
    for _ in range(num_experts):
        b = to_fp8(torch.randn((K, N), device=device))
        w_ref, w_q, w_s, _ = cutlass_quantize(
            atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False
        )
        w_refs.append(w_ref)
        w_qs.append(w_q)
        w_ss.append(w_s)
    w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)
    # NOTE(review): sizes are listed (N, M, K) — presumably because the
    # kernel swaps A/B (logical M and N); confirm against the kernel docs.
    problem_sizes = torch.tensor(
        [[N, M, K] for M in Ms], dtype=torch.int32, device=device
    )
    # Exclusive prefix sum of Ms: starting row of each expert's tokens.
    expert_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int64),
            torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1],
        ]
    ).to(device=device)
    # B strides and group scale strides.
    b_strides = packed_layout
    group_scale_strides = torch.zeros(
        (num_experts, 2), dtype=torch.int64, device=device
    )
    group_scale_strides[:, 0] = N
    return MoETestSetup(
        num_experts=num_experts,
        K=K,
        N=N,
        Ms=Ms,
        M_full=M_full,
        a=a,
        a_ref=a_ref,
        a_strides=a_strides,
        out=out,
        c_strides=c_strides,
        per_tok_scales=per_tok_scales,
        per_chan_scales=per_chan_scales,
        w_refs=w_refs,
        w_q_packed=w_q_packed,
        w_s_packed=w_s_packed,
        problem_sizes=problem_sizes,
        expert_offsets=expert_offsets,
        b_strides=b_strides,
        group_scale_strides=group_scale_strides,
    )
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
    """Build the reference output with one torch._scaled_mm call per expert."""
    out_ref = torch.empty_like(setup.out)
    row_starts = setup.expert_offsets.cpu().tolist()
    row_ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()
    for expert in range(setup.num_experts):
        lo, hi = row_starts[expert], row_ends[expert]
        if lo == hi:
            # This expert received no tokens.
            continue
        out_ref[lo:hi] = torch._scaled_mm(
            setup.a_ref[lo:hi].to(torch.float8_e4m3fn),
            setup.w_refs[expert].to(torch.float8_e4m3fn).t().contiguous().t(),
            setup.per_tok_scales[lo:hi],  # (M, 1)
            setup.per_chan_scales[expert].reshape(1, -1),  # (1, N)
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )
    return out_ref
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("random_zero", [True, False])
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
    # Compare the CUTLASS W4A8 grouped GEMM against a per-expert
    # torch._scaled_mm reference over randomized per-expert token counts.
    num_experts, N, K = shape
    current_platform.seed_everything(42)
    setup = make_moe_test_setup(
        num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero
    )
    ops.cutlass_w4a8_moe_mm(
        setup.out,
        setup.a,
        setup.w_q_packed,
        setup.per_tok_scales,
        setup.per_chan_scales,
        setup.w_s_packed,
        GROUP_SIZE,
        setup.expert_offsets,
        setup.problem_sizes,
        setup.a_strides,
        setup.b_strides,
        setup.c_strides,
        setup.group_scale_strides,
    )
    # The kernel runs asynchronously; sync before reading the output.
    torch.cuda.synchronize()
    out_ref = compute_moe_reference_output(setup)
    torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2)
class W4A8MoELayer(torch.nn.Module):
    """
    Minimal wrapper module to test cuda graphs: forward() launches the
    W4A8 grouped GEMM into the pre-allocated setup.out buffer.
    """

    def __init__(self, setup: MoETestSetup):
        super().__init__()
        # All weights, scales, and launch metadata come from the setup.
        self.setup = setup

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        # Writes into s.out in-place and returns it.
        s = self.setup
        ops.cutlass_w4a8_moe_mm(
            s.out,
            a,
            s.w_q_packed,
            s.per_tok_scales,
            s.per_chan_scales,
            s.w_s_packed,
            GROUP_SIZE,
            s.expert_offsets,
            s.problem_sizes,
            s.a_strides,
            s.b_strides,
            s.c_strides,
            s.group_scale_strides,
        )
        return s.out
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
def test_cutlass_w4a8_moe_mm_cuda_graph():
    # Verify the kernel can be captured in a CUDA graph and that a replay
    # reproduces the reference result.
    current_platform.seed_everything(42)
    # Fixed config for CUDA graph test (single parameter point).
    num_experts = 8
    K = 512
    N = 2048
    setup = make_moe_test_setup(
        num_experts=num_experts,
        K=K,
        N=N,
        max_blocks=32,
    )
    # Construct model that calls the grouped GEMM kernel.
    model = W4A8MoELayer(setup)
    # Build reference output once.
    out_ref = compute_moe_reference_output(setup)
    # Capture and run the model in a CUDA graph.
    a_static = setup.a.clone()  # static input tensor for graph replay
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out_static = model(a_static)
    # Zero the captured output buffer, then replay the graph; the replay
    # must regenerate the full result.
    out_static.zero_()
    g.replay()
    torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2)

View File

@ -695,6 +695,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format) return torch.empty_like(b, memory_format=torch.contiguous_format)
@register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "allspark_w8a16_gemm"): if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
@ -1058,6 +1062,7 @@ def get_cutlass_moe_mm_problem_sizes(
n: int, n: int,
k: int, k: int,
blockscale_offsets: torch.Tensor | None = None, blockscale_offsets: torch.Tensor | None = None,
force_swap_ab: bool | None = None,
): ):
""" """
Compute only the per-expert problem sizes needed by the two grouped matrix Compute only the per-expert problem sizes needed by the two grouped matrix
@ -1067,9 +1072,20 @@ def get_cutlass_moe_mm_problem_sizes(
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs multiplication for the two grouped MMs
used in the fused MoE operation. used in the fused MoE operation.
Optional:
- force_swap_ab: If set to True or False, explicitly enable or disable the
A/B input swap optimization. If None (default), the swap
is selected automatically based on tensor sizes.
""" """
return torch.ops._C.get_cutlass_moe_mm_problem_sizes( return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets topk_ids,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
blockscale_offsets,
force_swap_ab,
) )
@ -1457,6 +1473,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_encode_and_reorder_int4b(b) return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
def cutlass_w4a8_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    b_group_scales: torch.Tensor,
    b_group_size: int,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    c_strides: torch.Tensor,
    group_scale_strides: torch.Tensor,
    maybe_schedule: str | None = None,
):
    """
    Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
    W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
    and both per-channel + per-token scaling in the epilogue.

    Args:
        out_tensors:
            Output buffer for all experts (updated in-place).
        a_tensors:
            FP8 (E4M3FN) activations for all experts.
        b_tensors:
            INT4-packed weight matrix for all experts, packed to INT32.
        a_scales:
            Per-token FP8 activation scales, applied in the epilogue.
        b_scales:
            Per-channel FP8 weight scales for each expert, applied in the
            epilogue.
        b_group_scales:
            FP8 scale values for group-wise INT4 weight blocks.
        b_group_size:
            Number of elements grouped under each entry of b_group_scales.
        expert_offsets:
            Cumulative token offsets marking each expert's first row.
        problem_sizes:
            Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
        a/b/c/group_scale_strides:
            Strides describing the memory layout of the input tensors.
        maybe_schedule:
            Optional override to choose a specific kernel or epilogue schedule.

    Returns:
        out_tensors updated in-place with the dequantized INT4xFP8 grouped
        GEMM result.
    """
    return torch.ops._C.cutlass_w4a8_moe_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        b_group_scales,
        b_group_size,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        group_scale_strides,
        maybe_schedule,
    )
def cutlass_encode_and_reorder_int4b_grouped(
    b_tensors: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode/reorder stacked per-expert INT4 weights for the W4A8 grouped
    GEMM. Returns (packed weights, layout/stride metadata)."""
    return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors)
if hasattr(torch.ops._C, "permute_cols"): if hasattr(torch.ops._C, "permute_cols"):
@register_fake("_C::permute_cols") @register_fake("_C::permute_cols")

View File

@ -63,8 +63,10 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8, CutlassBatchedExpertsFp8,
CutlassExpertsFp8, CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_fp4, cutlass_moe_fp4,
cutlass_moe_fp8, cutlass_moe_fp8,
cutlass_moe_w4a8_fp8,
) )
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@ -88,8 +90,10 @@ if HAS_TRITON:
"grouped_topk", "grouped_topk",
"cutlass_moe_fp8", "cutlass_moe_fp8",
"cutlass_moe_fp4", "cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8", "CutlassExpertsFp8",
"CutlassBatchedExpertsFp8", "CutlassBatchedExpertsFp8",
"CutlassExpertsW4A8Fp8",
"TritonExperts", "TritonExperts",
"BatchedTritonExperts", "BatchedTritonExperts",
"DeepGemmExperts", "DeepGemmExperts",

View File

@ -143,6 +143,7 @@ class FusedMoEQuantDesc:
scale: Union[torch.Tensor, "PrecisionConfig", None] = None scale: Union[torch.Tensor, "PrecisionConfig", None] = None
# Quantization alphas or gscales, used for nvfp4 types. # Quantization alphas or gscales, used for nvfp4 types.
# W4A8 FP8: used for per-channel scales
# TODO(bnell): put some of these in subclasses # TODO(bnell): put some of these in subclasses
alpha_or_gscale: torch.Tensor | None = None alpha_or_gscale: torch.Tensor | None = None
@ -442,7 +443,9 @@ class FusedMoEQuantConfig:
- a1_scale: Optional scale to be used for a1. - a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2. - a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4). - g1_alphas: Optional global quantization scales for w1 (for nvfp4).
per-channel scales for w1 (for W4A8 FP8).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4). - g2_alphas: Optional global quantization scales for w2 (for nvfp4).
per-channel scales for w2 (for W4A8 FP8).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4). - a1_gscale: Optional global quantization scales for a1 (for nvfp4).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4). - a2_gscale: Optional global quantization scales for a2 (for nvfp4).
- w1_bias: Optional biases for w1 (GPT OSS Triton). - w1_bias: Optional biases for w1 (GPT OSS Triton).
@ -461,6 +464,7 @@ class FusedMoEQuantConfig:
"mxfp4", "mxfp4",
"mxfp6_e3m2", "mxfp6_e3m2",
"mxfp6_e2m3", "mxfp6_e2m3",
"int4",
} }
if weight_dtype is None: if weight_dtype is None:
@ -671,6 +675,31 @@ def int8_w8a16_moe_quant_config(
) )
def int4_w4afp8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for fp8 activations and int4 weights.

    Args:
        w1_scale / w2_scale: weight scales for w1 / w2.
        g1_alphas / g2_alphas: per-channel scales for w1 / w2 (W4A8 reuses
            the nvfp4 alpha slots for these).
        per_act_token_quant: use per-token activation quantization.
        per_out_ch_quant: use per-output-channel weight quantization.
        block_shape: optional block quantization shape.
    """
    return FusedMoEQuantConfig.make(
        torch.float8_e4m3fn,  # quant dtype for activations
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
        weight_dtype="int4",  # weight dtype for weights
    )
def biased_moe_quant_config( def biased_moe_quant_config(
w1_bias: torch.Tensor | None, w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None, w2_bias: torch.Tensor | None,

View File

@ -1052,3 +1052,404 @@ def run_cutlass_block_scaled_fused_experts(
return ( return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1) ).sum(dim=1)
# W4A8
def run_cutlass_moe_w4a8_fp8(
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
    activation_callable: Callable,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    w1_chan_scale: torch.Tensor,
    w2_chan_scale: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_num_tokens: torch.Tensor | None,
    out_dtype: torch.dtype,
    per_act_token: bool,
    per_out_ch: bool,
    use_batched_format: bool,
    topk_weights: torch.Tensor | None,
    group_size: int,
):
    """
    Run the W4A8 (INT4 weights, FP8 activations) fused-MoE pipeline:
    permute -> grouped GEMM 1 -> activation -> fp8 requant ->
    grouped GEMM 2 -> unpermute, writing the result into `output`.

    w1/w2 hold INT4 weights packed 8-per-int32. w1_scale/w2_scale are fp8
    group scales; w1_chan_scale/w2_chan_scale are fp32 per-channel scales
    applied in the GEMM epilogue together with the per-token a scales.
    """
    # Activations arrive already fp8-quantized with per-token scales.
    a1q = hidden_states
    M = a1q.size(0)
    local_E = w1.size(0)
    device = a1q.device
    _, K, N_packed = w2.shape
    N = N_packed * 8  # logical N, pack 8 int4 into 1 int32
    # Precondition checks for the only supported configuration.
    assert per_act_token, "W4A8 must use per-token scales"
    assert per_out_ch, "W4A8 must use per-channel scales"
    assert w1_scale is not None
    assert w2_scale is not None
    assert w1_scale.dtype == torch.float8_e4m3fn
    assert w2_scale.dtype == torch.float8_e4m3fn
    assert w1.dtype == torch.int32
    assert w2.dtype == torch.int32
    assert w1_chan_scale.dtype == torch.float32
    assert w2_chan_scale.dtype == torch.float32
    assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
    assert a1q_scale is not None
    assert a2_scale is None
    assert out_dtype in [torch.bfloat16], f"Invalid output dtype: {out_dtype}"
    if expert_map is not None:
        assert expert_num_tokens is None
    assert not use_batched_format, "batched format not supported yet"
    assert group_size == 128, f"Only group size 128 supported but got {group_size=}"
    assert global_num_experts != -1
    assert w1.size(2) * 8 == K, (
        f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
    )
    # Translate info from expert_map to topk_ids (-1 marks non-local experts)
    if expert_map is not None:
        local_topk_ids = torch.where(
            expert_map[topk_ids] != -1, expert_map[topk_ids], -1
        )
    else:
        local_topk_ids = topk_ids
    topk = local_topk_ids.size(1)
    # Carve all intermediate buffers out of the two caller-provided
    # workspaces; buffers from the same workspace are reused sequentially.
    a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
    mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
    act_out = _resize_cache(workspace2, (M * topk, N))
    # original workspace are based on input hidden_states dtype (bf16)
    quant_out = _resize_cache(
        workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N)
    )
    mm2_out = _resize_cache(workspace2, (M * topk, K))
    problem_sizes1 = torch.empty(
        (global_num_experts, 3), dtype=torch.int32, device=device
    )
    problem_sizes2 = torch.empty(
        (global_num_experts, 3), dtype=torch.int32, device=device
    )
    num_expert = global_num_experts if expert_map is None else expert_map.size(0)
    # permuted a1q reuses workspace2
    a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
        a1q,
        a1q_scale,
        topk_ids,
        num_expert,
        local_E,
        expert_map,
        permuted_hidden_states=a1q_perm,
    )
    # Drop the trailing total so each entry is an expert's starting row.
    expert_offsets = expert_offsets[:-1]
    # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
    ops.get_cutlass_moe_mm_problem_sizes(
        local_topk_ids,
        problem_sizes1,
        problem_sizes2,
        global_num_experts,
        N,
        K,
        force_swap_ab=True,
    )
    ops.cutlass_w4a8_moe_mm(
        mm1_out,
        a1q,
        w1,
        a1q_scale,
        w1_chan_scale,
        w1_scale,
        group_size,
        expert_offsets,
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides1,
    )
    # Gated activation: consumes (M*topk, 2N) and produces (M*topk, N).
    activation_callable(act_out, mm1_out)
    # Requantize the activation output to fp8 with dynamic per-token scales.
    a2q, a2q_scale = ops.scaled_fp8_quant(
        act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
    )
    if expert_map is not None:
        # Rows routed to non-local experts are never written by the grouped
        # GEMM, so zero the buffer first.
        mm2_out.fill_(0)
    ops.cutlass_w4a8_moe_mm(
        mm2_out,
        a2q,
        w2,
        a2q_scale,
        w2_chan_scale,
        w2_scale,
        group_size,
        expert_offsets,
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
    )
    # for non-chunking mode the output is resized from workspace13
    # so we need to make sure mm2_out uses workspace2.
    moe_unpermute(
        out=output,
        permuted_hidden_states=mm2_out,
        topk_weights=topk_weights,
        inv_permuted_idx=inv_perm,
    )
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
    """
    Modular-kernel experts implementation backed by the CUTLASS W4A8
    (INT4 weight, FP8 activation) grouped GEMM. Stride/layout tensors are
    precomputed by the caller and captured at construction time; apply()
    delegates to run_cutlass_moe_w4a8_fp8.
    """

    def __init__(
        self,
        out_dtype: torch.dtype | None,
        a_strides1: torch.Tensor,
        a_strides2: torch.Tensor,
        b_strides1: torch.Tensor,
        b_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        s_strides1: torch.Tensor,
        s_strides2: torch.Tensor,
        quant_config: FusedMoEQuantConfig,
        group_size: int,
    ):
        super().__init__(quant_config)
        # None means: fall back to the activation dtype at apply() time.
        self.out_dtype = out_dtype
        # Strides for the two grouped GEMMs (a = activations, b = weights,
        # c = outputs, s = group scales).
        self.a_strides1 = a_strides1
        self.a_strides2 = a_strides2
        self.b_strides1 = b_strides1
        self.b_strides2 = b_strides2
        self.c_strides1 = c_strides1
        self.c_strides2 = c_strides2
        self.s_strides1 = s_strides1
        self.s_strides2 = s_strides2
        # Group size of the int4 weight quantization.
        self.group_size = group_size

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        # Standard (non-batched) layout on both input and output.
        return (
            mk.FusedMoEActivationFormat.Standard,
            mk.FusedMoEActivationFormat.Standard,
        )

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # topk weights and reduction are fused in moe_unpermute cuda kernel
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        return self.out_dtype if self.out_dtype is not None else act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # Two scratch buffers sized for the larger of the GEMM1/GEMM2
        # intermediates, plus the final (M, K) output shape.
        workspace1 = (M * topk, max(N, K))
        workspace2 = (M * topk, max(N // 2, K))
        output = (M, K)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
        assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
        expert_num_tokens = None
        # Bind the activation name so the runner only needs (out, in).
        activation_callable = lambda o, i: self.activation(activation, o, i)
        use_batched_format = (
            self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
        )
        assert not use_batched_format, "batched format not supported"
        in_dtype = hidden_states.dtype
        run_cutlass_moe_w4a8_fp8(
            output,
            hidden_states,
            w1,
            w2,
            topk_ids,
            activation_callable,
            global_num_experts,
            expert_map,
            self.w1_scale,
            self.w2_scale,
            a1q_scale,
            a2_scale,
            self.g1_alphas,  # per-channel scales
            self.g2_alphas,  # per-channel scales
            self.a_strides1,
            self.a_strides2,
            self.b_strides1,
            self.b_strides2,
            self.c_strides1,
            self.c_strides2,
            self.s_strides1,
            self.s_strides2,
            workspace13,
            workspace2,
            expert_num_tokens,
            self.out_dtype if self.out_dtype is not None else in_dtype,
            self.per_act_token_quant,
            self.per_out_ch_quant,
            use_batched_format,
            topk_weights,
            self.group_size,
        )
def cutlass_moe_w4a8_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    activation: str = "silu",
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    group_size: int = 128,
) -> torch.Tensor:
    """
    Compute a w4a8-quantized Mixture of Experts (MoE) layer using two sets
    of quantized weights, w1_q and w2_q, and a top-k gating mechanism. The
    matrix multiplications are implemented with CUTLASS mixed-dtype grouped
    GEMM.
    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of quantized expert weights.
        Shape: [num_experts, 2*N, K // packed_factor]
    - w2_q (torch.Tensor): The second set of quantized expert weights.
        Shape: [num_experts, K, N // packed_factor]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mappings.
    - a_strides1 (torch.Tensor): The input strides for the first gemm.
        Shape: [num_experts]
    - a_strides2 (torch.Tensor): The input strides for the second gemm.
        Shape: [num_experts]
    - b_strides1 (torch.Tensor): The packed layout for the first gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - b_strides2 (torch.Tensor): The packed layout for the second gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - c_strides1 (torch.Tensor): The output strides for the first gemm.
        Shape: [num_experts]
    - c_strides2 (torch.Tensor): The output strides for the second gemm.
        Shape: [num_experts]
    - s_strides1 (torch.Tensor): strides for the group-wise scales, first gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - s_strides2 (torch.Tensor): strides for the group-wise scales, second gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - quant_config (FusedMoEQuantConfig): Quantization configuration carrying
        the group-wise and per-channel weight scales.
    - activation (str): The activation function to use.
    - expert_map (Optional[torch.Tensor]): In the case of expert parallel,
        every rank is responsible for a subset of experts. expert_map maps
        global expert-ids to local expert-ids; expert_map[i] == -1 means this
        rank is not responsible for global expert-id i.
    - apply_router_weight_on_input (bool): When true, the topk weights are
        applied directly on the inputs. This is only applicable when topk is 1.
    - global_num_experts (int): The total number of experts; -1 means "infer
        from w1_q".
    - group_size (int): The number of weights per scale factor.
    Returns:
    - torch.Tensor: The bf16 output tensor after applying the MoE layer.
    """
    assert quant_config is not None
    # -1 means "use every expert present in the weight tensor".
    num_experts = w1_q.size(0) if global_num_experts == -1 else global_num_experts
    experts = CutlassExpertsW4A8Fp8(
        out_dtype=a.dtype,
        a_strides1=a_strides1,
        a_strides2=a_strides2,
        b_strides1=b_strides1,
        b_strides2=b_strides2,
        c_strides1=c_strides1,
        c_strides2=c_strides2,
        s_strides1=s_strides1,
        s_strides2=s_strides2,
        quant_config=quant_config,
        group_size=group_size,
    )
    moe_kernel = mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), experts)
    return moe_kernel(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        activation=activation,
        global_num_experts=num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )

View File

@ -256,7 +256,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if format is not None
else is_activation_quantization_format(quant_format)
)
# w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations = quant_config.get("input_activations")
if act_quant_format or input_activations:

View File

@ -33,6 +33,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, int8_w8a16_moe_quant_config,
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
@ -79,7 +80,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, all_close_1d,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
@ -204,6 +209,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsW8A8Int8MoEMethod( return CompressedTensorsW8A8Int8MoEMethod(
weight_quant, input_quant, layer.moe_config weight_quant, input_quant, layer.moe_config
) )
elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant):
logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod")
return CompressedTensorsW4A8Fp8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod( return CompressedTensorsW4A8Int8MoEMethod(
weight_quant, input_quant, layer.moe_config weight_quant, input_quant, layer.moe_config
@ -2428,3 +2438,331 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input, apply_router_weight_on_input,
int(_act_kind(activation)), int(_act_kind(activation)),
) )
class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for W4A8: int4 group-quantized weights with dynamic
    per-token FP8 activation quantization, executed via CUTLASS mixed-dtype
    grouped GEMM.

    Checkpoint weights arrive as uint4b8 nibbles packed 8-per-int32 with BF16
    group scales; ``process_weights_after_loading`` converts them in place to
    signed int4, encodes/reorders them into the kernel's layout, and splits
    the BF16 group scales into packed FP8 group scales plus fp32 per-channel
    scales.
    """
    def __init__(
        self,
        weight_quant: QuantizationArgs,
        input_quant: QuantizationArgs,
        moe: FusedMoEConfig,
        layer_name: str | None = None,
    ):
        super().__init__(moe)
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.group_size = self.weight_quant.group_size
        self.num_bits = self.weight_quant.num_bits
        # Number of int4 values stored per int32 word (32 // 4 == 8).
        self.packed_factor = 32 // self.num_bits
        assert self.weight_quant.symmetric, (
            "Only symmetric quantization is supported for W4A8 MoE"
        )
        # Group activation ordering is not supported by this kernel path.
        assert self.weight_quant.actorder != "group"
        assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE"
        self.disable_expert_map = False
        self.layer_name = layer_name
        # NOTE(review): local imports — presumably deferred to avoid an
        # import cycle at module load time; confirm.
        from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            GroupShape,
        )
        # Dynamic per-token FP8 quantizer; also reused after weight loading
        # to convert the BF16 group scales to (fp8, per-channel) scales.
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate packed int4 weights, BF16 group scales, and shape
        parameters for w13 ([E, 2N, K/8] packed) and w2 ([E, K, N/8] packed).
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        # requirement for CUTLASS reorder_tensor
        assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256"
        assert intermediate_size_per_partition % 256 == 0, (
            f"{intermediate_size_per_partition=} must be divisible by 256"
        )
        # storage type, pack 8xint4 into int32
        params_dtype = torch.int32
        # WEIGHTS
        w13_weight_packed = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.packed_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_packed", w13_weight_packed)
        set_weight_attrs(w13_weight_packed, extra_weight_attrs)
        w2_weight_packed = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.packed_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_packed", w2_weight_packed)
        set_weight_attrs(w2_weight_packed, extra_weight_attrs)
        # SCALES
        # weight_scale refers to the group-wise scales:
        # they are initially loaded as bf16; we convert to fp8 after loading
        # (see process_weights_after_loading).
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=layer.orig_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=layer.orig_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-GROUP quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        # weight shapes (2 ints per expert, filled by the weight loader)
        w2_weight_shape = torch.nn.Parameter(
            torch.empty(num_experts, 2), requires_grad=False
        )
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
        w13_weight_shape = torch.nn.Parameter(
            torch.empty(num_experts, 2), requires_grad=False
        )
        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
        # don't use input scales: activations are quantized dynamically
        layer.w13_input_scale = None
        layer.w2_input_scale = None
    def process_weights_after_loading(self, layer):
        """Build kernel stride tensors and convert loaded weights/scales to
        the CUTLASS grouped-GEMM layout. Order matters: the in-place
        uint4b8 -> signed int4 conversion must happen before encode/reorder.
        """
        device = layer.w13_weight_packed.device
        # STRIDES
        # A, C
        # a_strides1 (gemm1 input width == hidden_size) equals c_strides2
        # (gemm2 output width == hidden_size), so one tensor serves both.
        self.a_strides1_c_strides2 = torch.full(
            (layer.local_num_experts,),
            layer.hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.a_strides2 = torch.full(
            (layer.local_num_experts,),
            layer.intermediate_size_per_partition,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides1 = torch.full(
            (layer.local_num_experts,),
            2 * layer.intermediate_size_per_partition,
            device=device,
            dtype=torch.int64,
        )
        # S (group-wise scales)
        # sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it
        self.s_strides1 = torch.zeros(
            (layer.local_num_experts, 2), device=device, dtype=torch.int64
        )
        self.s_strides1[:, 0] = 2 * layer.intermediate_size_per_partition
        self.s_strides2 = torch.zeros(
            (layer.local_num_experts, 2), device=device, dtype=torch.int64
        )
        self.s_strides2[:, 0] = layer.hidden_size
        # encode and reorder weight tensors, and get the layout to pass to
        # the grouped gemm kernel. `b_strides1/2` specifies the entire layout
        convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed)
        w13_weight_shuffled, self.b_strides1 = (
            ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed)
        )
        replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled)
        convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed)
        w2_weight_shuffled, self.b_strides2 = (
            ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed)
        )
        replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled)
        # convert bf16 scales to (fp8_scales, channel_scales)
        w13_weight_scale, w13_weight_chan_scale = convert_bf16_scales_to_fp8(
            self.quant_fp8, layer.w13_weight_scale
        )
        w2_weight_scale, w2_weight_chan_scale = convert_bf16_scales_to_fp8(
            self.quant_fp8, layer.w2_weight_scale
        )
        # register channel scales
        layer.register_parameter(
            "w13_weight_chan_scale",
            torch.nn.Parameter(w13_weight_chan_scale, requires_grad=False),
        )
        layer.register_parameter(
            "w2_weight_chan_scale",
            torch.nn.Parameter(w2_weight_chan_scale, requires_grad=False),
        )
        # The scales are stored as (E, N, K // 128) but the kernel expects
        # (E, K // 128, N) in row-major format, so we need to permute the last 2 dims
        # and make it contiguous
        w13_weight_scale_packed = ops.cutlass_pack_scale_fp8(
            w13_weight_scale.permute(0, 2, 1).contiguous()
        )
        replace_parameter(layer, "w13_weight_scale", w13_weight_scale_packed)
        w2_weight_scale_packed = ops.cutlass_pack_scale_fp8(
            w2_weight_scale.permute(0, 2, 1).contiguous()
        )
        replace_parameter(layer, "w2_weight_scale", w2_weight_scale_packed)
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        # No method-specific prepare/finalize; defer to the base class.
        return super().maybe_make_prepare_finalize(routing_tables)
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # Store quantization scales; both per-group and per-channel.
        # Note we haven't specified the group size here because
        # the quant config logic assumes group-wise scaling
        # and channel-wise scaling are exclusive.
        return int4_w4afp8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,  # group scale
            w2_scale=layer.w2_weight_scale,  # group scale
            g1_alphas=layer.w13_weight_chan_scale,
            g2_alphas=layer.w2_weight_chan_scale,
            per_act_token_quant=True,  # always use dynamic per-token
            per_out_ch_quant=True,  # always use per-channel
        )
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the CUTLASS W4A8 experts impl with the precomputed strides.

        Also decides whether expert_map must be disabled (multiple
        dispatchers, or the impl itself doesn't support it).
        """
        assert self.moe_quant_config is not None
        assert (
            prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
        ), "BatchedExperts not supported"
        from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
        experts: FusedMoEPermuteExpertsUnpermute
        logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
        experts = CutlassExpertsW4A8Fp8(
            out_dtype=self.moe.in_dtype,
            a_strides1=self.a_strides1_c_strides2,
            a_strides2=self.a_strides2,
            b_strides1=self.b_strides1,
            b_strides2=self.b_strides2,
            c_strides1=self.c_strides1,
            c_strides2=self.a_strides1_c_strides2,
            s_strides1=self.s_strides1,
            s_strides2=self.s_strides2,
            quant_config=self.moe_quant_config,
            group_size=self.group_size,
        )
        num_dispatchers = prepare_finalize.num_dispatchers()
        self.disable_expert_map = (
            num_dispatchers > 1 or not experts.supports_expert_map()
        )
        return experts
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ):
        """Route tokens and run the W4A8 CUTLASS MoE for this layer."""
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
            )
        assert self.moe_quant_config is not None
        # NOTE(review): routing args (top_k, renormalize, scoring_func, ...)
        # are not forwarded here; select_experts presumably reads them from
        # the layer — confirm against FusedMoE.select_experts.
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_w4a8_fp8,
        )
        return cutlass_moe_w4a8_fp8(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            topk_weights,
            topk_ids,
            quant_config=self.moe_quant_config,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=None if self.disable_expert_map else expert_map,
            a_strides1=self.a_strides1_c_strides2,
            a_strides2=self.a_strides2,
            b_strides1=self.b_strides1,
            b_strides2=self.b_strides2,
            c_strides1=self.c_strides1,
            c_strides2=self.a_strides1_c_strides2,
            s_strides1=self.s_strides1,
            s_strides2=self.s_strides2,
            group_size=self.group_size,
        )
    @property
    def supports_eplb(self) -> bool:
        # Expert-parallel load balancing is not implemented for this method.
        return False

View File

@ -128,14 +128,15 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
), ),
) )
# TODO(czhu): allocate the packed fp8 scales memory here? # After loading, we will transform bf16 -> fp8 ->
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8` # expand by 8x via `cutlass_pack_scale_fp8`
# and construct per-channel fp32 scales.
weight_scale_args = { weight_scale_args = {
"weight_loader": weight_loader, "weight_loader": weight_loader,
"data": torch.empty( "data": torch.empty(
output_size_per_partition, output_size_per_partition,
scales_and_zp_size, scales_and_zp_size,
dtype=torch.float8_e4m3fn, dtype=params_dtype,
), ),
} }
@ -152,17 +153,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
) )
# per-channel scales
weight_chan_scale = ChannelQuantScaleParameter(
data=torch.empty((output_size_per_partition, 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("weight_chan_scale", weight_chan_scale)
self.kernel = kernel_type( self.kernel = kernel_type(
mp_linear_kernel_config, mp_linear_kernel_config,

View File

@ -6,7 +6,11 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
"CUTLASS W4A8, only supported int4", "CUTLASS W4A8, only supported int4",
) )
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128: if c.group_size != 128:
return False, "Only group_size 128 is supported" return False, "Only group_size 128 is supported"
@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1} # `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x): def transform_w_q(x):
assert isinstance(x, BasevLLMParameter) assert isinstance(x, BasevLLMParameter)
convert_packed_uint4b8_to_signed_int4_inplace(x.data)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x return x
@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
x.data = ops.cutlass_pack_scale_fp8(x.data) x.data = ops.cutlass_pack_scale_fp8(x.data)
return x return x
w_s = getattr(layer, self.w_s_name)
fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
w_s.data = fp8_scales
# register per-channel scales
layer.register_parameter(
"weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
)
# Encode/reorder weights and pack scales # Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, self.w_s_name, transform_w_s)
self._transform_param(layer, "weight_chan_scale", lambda x: x)
def apply_weights( def apply_weights(
self, self,

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from collections.abc import Mapping from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from types import MappingProxyType from types import MappingProxyType
from typing import ClassVar, NamedTuple from typing import ClassVar, NamedTuple
@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int() capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability) return cutlass_scaled_mm_supports_fp4(capability)
def convert_bf16_scales_to_fp8(
    quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales)
    expected by W4A8 GEMM kernels.

    Args:
        quant_fp8: dynamic per-token FP8 quantizer (e.g. ``QuantFP8`` with
            ``GroupShape.PER_TOKEN``); returns (quantized rows, per-row
            fp32 scales).
        scales: contiguous CUDA tensor of shape (..., K // group_size).

    Returns:
        fp8_scales: float8_e4m3fn tensor with the same shape as ``scales``.
        chan_scales: fp32 per-channel scales of shape (..., 1).
    """
    assert scales.is_contiguous(), (
        f"scale tensor must be contiguous, got {scales.stride()=}"
    )
    assert scales.is_cuda, "scales must be on gpu"
    orig_shape = scales.shape
    k_groups = orig_shape[-1]
    flat_scales = scales.view(-1, k_groups)
    fp8_scales, chan_scales = quant_fp8(flat_scales)
    # Move a factor of 8 from the fp8 group scales into the fp32 channel
    # scales; the product is unchanged (presumably to keep the fp8 values
    # in a better part of the e4m3 range — TODO confirm kernel expectation).
    fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn)
    chan_scales *= 8.0
    # restore original shape
    fp8_scales = fp8_scales.view(orig_shape)
    # BUG FIX: Tensor.view takes sizes (*shape), not a tuple followed by -1;
    # unpack the leading dims so this does not raise a TypeError.
    chan_scales = chan_scales.view(*orig_shape[:-1], -1)
    return fp8_scales, chan_scales
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
    """
    Convert int4b8 nibbles (biased-unsigned values in [0, 15], packed
    8-per-int32) to signed int4 values in [-8, 7], in place, and return the
    same tensor.

    For each 4-bit nibble v the mapping is (v - 8) mod 16, which equals
    v XOR 8 (removing the bias of 8 just flips the nibble's high bit). A
    single XOR of every int32 word with 0x88888888 therefore converts all
    8 nibbles at once — no per-nibble loop needed.
    """
    assert t.is_cuda, "tensor must be on gpu"
    assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"
    # BUG FIX: the previous per-nibble loop computed `t &= ~(0xF << 28)`,
    # whose Python scalar (-0xF0000001) is outside the int32 range and is
    # rejected by PyTorch's scalar overflow check. Use the two's-complement
    # int32 value of 0x88888888 so the scalar always fits.
    mask = 0x88888888 - (1 << 32)  # == -0x77777778, i.e. 0x88888888 as int32
    t ^= mask
    return t