mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
parent
ea657f2078
commit
f6227c22ab
@ -874,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||||
set(SRCS
|
set(SRCS
|
||||||
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")
|
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
||||||
|
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
||||||
|
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
|
||||||
|
)
|
||||||
|
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${SRCS}"
|
SRCS "${SRCS}"
|
||||||
|
|||||||
@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
|
|||||||
void get_cutlass_moe_mm_problem_sizes(
|
void get_cutlass_moe_mm_problem_sizes(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
|
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||||
|
std::optional<bool> force_swap_ab = std::nullopt);
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes1,
|
torch::Tensor& problem_sizes1,
|
||||||
|
|||||||
104
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
Normal file
104
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
|
||||||
|
#include "core/scalar_type.hpp"
|
||||||
|
#include "cutlass/bfloat16.h"
|
||||||
|
#include "cutlass/float8.h"
|
||||||
|
|
||||||
|
// Fills the per-expert pointer tables consumed by the W4A8 grouped GEMM.
//
// Type conventions used by the launcher in this file:
//   * ElementB is int32 (eight int4 values packed into one word)
//   * ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8>
//     (packed fp8 group scales)
//
// Launch layout: one thread per expert; `threadIdx.x` is used directly as the
// expert index and there is no bounds check, so the launch must supply exactly
// one thread for every entry of `expert_offsets`.
//
// For expert `e` with starting row `r = expert_offsets[e]` the kernel writes:
//   a_offsets[e]              = a_base + r * k
//   out_offsets[e]            = out_base + r * n
//   a_scales_offsets[e]       = a_scales_base + r
//   b_scales_offsets[e]       = b_scales_base + n * e
//   b_offsets[e]              = b_base + e * k * n / 8
//   b_group_scales_offsets[e] = b_group_scales_base + e * scale_k * n
template <typename ElementA, typename ElementB, typename ElementC,
          typename ElementAccumulator, typename ElementGroupScale>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
    ElementB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
    int64_t scale_k) {
  // eight int4 values are stored in every int32 element of B
  constexpr int int4_per_word = 8;

  const int e = threadIdx.x;
  const int64_t first_row = expert_offsets[e];

  // Activation-side pointers: identical to the w8a8 grouped GEMM.
  a_offsets[e] = a_base_as_int + first_row * k;
  out_offsets[e] = out_base_as_int + first_row * n;
  a_scales_offsets[e] = a_scales_base_as_int + first_row;
  b_scales_offsets[e] = b_scales_base_as_int + (n * e);

  // Weight-side pointers: w4a8 specific. The products are evaluated in 64-bit
  // because k, n and scale_k are int64_t.
  b_offsets[e] = b_base_as_int + (e * k * n / int4_per_word);
  b_group_scales_offsets[e] = b_group_scales_base_as_int + (e * scale_k * n);
}
|
||||||
|
|
||||||
|
// Expands to one `else if` arm of an output-dtype dispatch chain. When
// `out_tensors.dtype()` equals TENSOR_C_TYPE it launches
// get_group_gemm_starts with C_TYPE as the output element type, using a single
// block of `num_experts` threads on `stream`. The expansion site must have the
// tensors named below, plus `num_experts`, `stream`, `n`, `k` and `scale_k`,
// in scope.
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                      \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                           \
    using StartsActT = cutlass::float_e4m3_t;                                \
    using StartsGroupScaleT = cutlass::Array<cutlass::float_e4m3_t, 8>;      \
    get_group_gemm_starts<StartsActT, int32_t, C_TYPE, float,                \
                          StartsGroupScaleT>                                 \
        <<<1, num_experts, 0, stream>>>(                                     \
            static_cast<int64_t*>(expert_offsets.data_ptr()),                \
            static_cast<StartsActT**>(a_ptrs.data_ptr()),                    \
            static_cast<int32_t**>(b_ptrs.data_ptr()),                       \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                      \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                  \
            static_cast<float**>(b_scales_ptrs.data_ptr()),                  \
            static_cast<StartsGroupScaleT**>(                                \
                b_group_scales_ptrs.data_ptr()),                             \
            static_cast<StartsActT*>(a_tensors.data_ptr()),                  \
            static_cast<int32_t*>(b_tensors.data_ptr()),                     \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                    \
            static_cast<float*>(a_scales.data_ptr()),                        \
            static_cast<float*>(b_scales.data_ptr()),                        \
            static_cast<StartsGroupScaleT*>(b_group_scales.data_ptr()), n,   \
            k, scale_k);                                                     \
  }
|
||||||
|
|
||||||
|
namespace {

// Host-side launcher for get_group_gemm_starts.
//
// Validates the dtypes of the inputs, derives the logical problem dimensions
// (n from `out_tensors.size(1)`, k from `a_tensors.size(1)`,
// scale_k = ceil(k / b_group_size)) and launches the kernel on the current
// CUDA stream of `a_tensors`' device so that the `*_ptrs` tensors receive one
// device pointer per expert.
//
// Parameters:
//   expert_offsets       int64 tensor; its length is the number of experts.
//   a_ptrs ... b_group_scales_ptrs
//                        output tensors that receive the per-expert pointers.
//   a_tensors            fp8 e4m3 activations.
//   b_tensors            int32 weights (int4 values, 8 per word).
//   out_tensors          bf16 output buffer.
//   a_scales, b_scales   float32 scales.
//   b_group_scales       fp8 e4m3 group scales.
//   b_group_size         group size along k; must be positive.
//
// Raises (via TORCH_CHECK) on a dtype mismatch, a non-positive group size,
// more experts than fit in one thread block, or a failed kernel launch.
void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);  // int4 8x packed into int32
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_group_scales.dtype() ==
              torch::kFloat8_e4m3fn);  // the underlying torch type is e4m3
  TORCH_CHECK(out_tensors.dtype() ==
              torch::kBFloat16);  // only support bf16 for now
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
  // scale_k below divides by the group size
  TORCH_CHECK(b_group_size > 0, "b_group_size must be positive, got ",
              b_group_size);

  // The kernel runs one thread per expert inside a single block, so the
  // expert count is bounded by the maximum number of threads per block.
  constexpr int64_t kMaxThreadsPerBlock = 1024;
  const int64_t num_experts_i64 = expert_offsets.size(0);
  TORCH_CHECK(num_experts_i64 <= kMaxThreadsPerBlock,
              "get_group_gemm_starts supports at most ", kMaxThreadsPerBlock,
              " experts, got ", num_experts_i64);
  if (num_experts_i64 == 0) {
    // Nothing to compute; a launch with zero threads would be an invalid
    // configuration.
    return;
  }
  int num_experts = static_cast<int>(num_experts_i64);

  // logical k, n
  int64_t n = out_tensors.size(1);
  int64_t k = a_tensors.size(1);
  int64_t scale_k = cutlass::ceil_div(k, b_group_size);

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Output-dtype dispatch. The bf16 check above currently makes the fp16 arm
  // unreachable; it is kept so the dispatch is ready if that check is relaxed.
  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }

  // Kernel launches do not report configuration errors by themselves; surface
  // them here instead of letting the GEMM consume uninitialised pointers.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "get_group_gemm_starts launch failed: ",
              cudaGetErrorString(launch_status));
}

}  // namespace
|
||||||
483
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
Normal file
483
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
Normal file
@ -0,0 +1,483 @@
|
|||||||
|
#include <vector>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||||
|
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
|
|
||||||
|
#include "cutlass/util/packed_stride.hpp"
|
||||||
|
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||||
|
|
||||||
|
// vllm includes
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include "cutlass_extensions/torch_utils.hpp"
|
||||||
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
|
#include "get_group_starts.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
#include "w4a8_utils.cuh"
|
||||||
|
|
||||||
|
namespace vllm::cutlass_w4a8_moe {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------
// Static configuration shared across all instantiations
// -------------------------------------------------------------------------------------

// One <M,N,K> triple per group.
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

// Element type fed to the MMA and element type of the quantized weights.
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;

// K extent of a tile: 128 * 8 bits worth of MmaType elements.
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
// 8 int4 values are packed into one int32.
static constexpr int PackFactor = 8;

// ---- A matrix ----
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;
// Alignment of A in units of elements (up to 16 bytes).
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

// ---- B matrix ----
using ElementB = QuantType;
using LayoutB = cutlass::layout::ColumnMajor;
// Memory access granularity/alignment of B in units of elements (up to 16
// bytes).
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

// The kernel manually swaps and transposes the operands, so the transposed
// input layouts are kept as well.
using LayoutA_Transpose = cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = cutlass::layout::LayoutTranspose<LayoutB>::type;

// A pointer type is passed so that the 3rd dimension of the stride is _0.
using StrideA =
    cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB =
    cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;

// CuTe layout for the reordered quantized tensor B. LayoutAtomQuant places
// values that will be read by the same thread in contiguous locations in
// global memory; it specifies the reordering within a single warp's fragment.
using LayoutAtomQuant =
    decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
    LayoutAtomQuant{}, Layout<Shape<int, int, Int<1>>, StrideB>{}));

// ---- group scales ----
using ElementScale = cutlass::float_e4m3_t;
using LayoutScale = cutlass::layout::RowMajor;

// ---- C matrix ----
using ElementC = cutlass::bfloat16_t;
using LayoutC = cutlass::layout::RowMajor;
// Memory access granularity/alignment of C in units of elements (up to 16
// bytes).
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

// ---- D matrix (same element type and layout as C) ----
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// ---- core kernel configuration ----
// Element type for internal accumulation.
using ElementAccumulator = float;
// Minimum SM that supports the intended feature.
using ArchTag = cutlass::arch::Sm90;
// Operator class tag.
using OperatorClass = cutlass::arch::OpClassTensorOp;
// Stage count maximized based on the tile size.
using StageCountType = cutlass::gemm::collective::StageCountAuto;

// Element type of the per-channel and per-token scales used by the epilogue.
using ElementSChannel = float;
|
||||||
|
|
||||||
|
// One concrete W4A8 grouped-GEMM configuration.
//
// Template parameters:
//   TileShape_MN      M/N extents of the CTA tile; TileShapeK is appended.
//   ClusterShape_MNK  thread-block cluster shape.
//   KernelSchedule    mainloop schedule tag handed to the CollectiveBuilder.
//   EpilogueSchedule  epilogue schedule tag handed to the CollectiveBuilder.
//
// The operands are passed to CUTLASS in swapped order (B first, then A), which
// is why the transposed C/D layouts and LayoutA_Transpose appear below.
template <class TileShape_MN, class ClusterShape_MNK, class KernelSchedule,
          class EpilogueSchedule>
struct W4A8GroupedGemmKernel {
  using TileShape =
      decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
  using ClusterShape = ClusterShape_MNK;

  // per-channel, per-token scales epilogue
  using ChTokScalesEpilogue =
      typename vllm::c3x::ScaledEpilogueArray<ElementAccumulator, ElementD,
                                              TileShape>;
  using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementSChannel, ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type*, AlignmentC,
          ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*,
          AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp;

  // =========================================================== MIXED INPUT
  // WITH SCALES
  // ===========================================================================
  // The scale information must be paired with the operand that will be
  // scaled. B is scaled here, so B's element type and the packed scale type
  // are passed together as a tuple.
  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*,
          AlignmentA, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>;

  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::InternalStrideC;
  using StrideD = typename GemmKernelShuffled::InternalStrideD;

  using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
  using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;
  using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;

  // Strides and layouts arrive as raw integer tensors and are reinterpreted
  // below, so their sizes must match what the callers pack.
  // group-scale stride: packed as 2 x int64
  static_assert(sizeof(StrideS) == 2 * sizeof(int64_t));
  // reordered B layout: packed as a whole number of int32 words
  static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
                "LayoutB_Reordered size must be divisible by 4 bytes");

  // Runs the grouped GEMM for this configuration.
  //
  // Parameters:
  //   out_tensors          output buffer; written through per-expert pointers.
  //   a_tensors            activations; also determines the CUDA device and
  //                        stream that are used.
  //   b_tensors            packed int4 weights; must be 3-D.
  //   a_scales, b_scales   per-token / per-channel scales for the epilogue.
  //   b_group_scales       packed group scales for B.
  //   b_group_size         group size forwarded to the mainloop.
  //   expert_offsets       per-expert starting rows; length = expert count.
  //   problem_sizes_torch  per-expert problem shapes, reinterpreted as
  //                        ProblemShape::UnderlyingProblemShape.
  //   a_strides, b_strides, c_strides, group_scale_strides
  //                        device tensors reinterpreted as StrideA,
  //                        LayoutB_Reordered, StrideC and StrideS arrays.
  //
  // Raises through TORCH_CHECK / CUTLASS_CHECK on invalid inputs or when
  // CUTLASS rejects or fails to run the problem.
  static void grouped_mm(
      torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
      const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
      const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
      const int64_t b_group_size, const torch::Tensor& expert_offsets,
      const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
      const torch::Tensor& b_strides, const torch::Tensor& c_strides,
      const torch::Tensor& group_scale_strides) {
    auto device = a_tensors.device();
    auto device_id = device.index();
    const at::cuda::OptionalCUDAGuard device_guard(device);
    auto stream = at::cuda::getCurrentCUDAStream(device_id);

    TORCH_CHECK(b_tensors.dim() == 3,
                "b_tensors must be 3-D: (experts, n, k / ", PackFactor, ")");
    int num_experts = static_cast<int>(expert_offsets.size(0));

    // One 64-bit slot per expert for every pointer table.
    auto options_int =
        torch::TensorOptions().dtype(torch::kInt64).device(device);
    torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
    torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);

    // get the correct offsets to pass to gemm
    run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
                              a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs,
                              a_tensors, b_tensors, out_tensors, a_scales,
                              b_scales, b_group_scales, b_group_size);

    // construct args
    using Args = typename GemmShuffled::Arguments;
    using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
    using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;

    ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
        static_cast<ProblemShape::UnderlyingProblemShape*>(
            problem_sizes_torch.data_ptr());
    ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};

    // SwapAB so B operands come first
    MainloopArguments mainloop_arguments{
        static_cast<const QuantType**>(b_ptrs.data_ptr()),
        static_cast<LayoutB_Reordered*>(b_strides.data_ptr()),
        static_cast<const MmaType**>(a_ptrs.data_ptr()),
        static_cast<StrideA*>(a_strides.data_ptr()),
        static_cast<const cutlass::Array<ElementScale, 8>**>(
            b_group_scales_ptrs.data_ptr()),
        static_cast<StrideS*>(group_scale_strides.data_ptr()),
        static_cast<int>(b_group_size)};

    EpilogueArguments epilogue_arguments{
        // since we are doing SwapAB the channel scales comes first, then token
        // scales
        ChTokScalesEpilogue::prepare_args(  // see ScaledEpilogueArray
            static_cast<const ElementAccumulator**>(
                b_scales_ptrs.data_ptr()),  // per-channel
            static_cast<const ElementAccumulator**>(
                a_scales_ptrs.data_ptr()),  // per-token
            true, true),
        nullptr,                                       // C
        static_cast<StrideC*>(c_strides.data_ptr()),   // C
        static_cast<ElementD**>(out_ptrs.data_ptr()),  // D
        static_cast<StrideC*>(c_strides.data_ptr())    // D
    };

    // Build the hardware info for the device of this call. A function-local
    // static would be initialised once with the first caller's device and
    // reused for every later device. ATen caches the device properties per
    // device, so the lookup is cheap.
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = device_id;
    hw_info.sm_count =
        at::cuda::getDeviceProperties(device_id)->multiProcessorCount;

    Args arguments{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape,
                   mainloop_arguments, epilogue_arguments, hw_info};

    // Allocate workspace
    size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
    torch::Tensor workspace =
        torch::empty(workspace_size,
                     torch::TensorOptions().dtype(torch::kU8).device(device));

    // Run GEMM
    GemmShuffled gemm;
    CUTLASS_CHECK(gemm.can_implement(arguments));
    CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
    CUTLASS_CHECK(gemm.run(stream));
  }
};
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
// Kernel instantiations and dispatch logic
// ----------------------------------------------------------------------------
using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;

// Shorthand for a configuration that uses the cooperative mainloop and
// epilogue schedules; only the tile and cluster shapes vary.
template <class TileShape_MN, class ClusterShape_MNK>
using CoopGroupedGemm =
    W4A8GroupedGemmKernel<TileShape_MN, ClusterShape_MNK, Coop, CoopEpi>;

// Naming scheme: Kernel_<TileM>x<TileN>_<ClusterM>x<ClusterN>x<ClusterK>_<Schedule>
using Kernel_128x16_1x1x1_Coop =
    CoopGroupedGemm<Shape<_128, _16>, Shape<_1, _1, _1>>;
using Kernel_128x16_2x1x1_Coop =
    CoopGroupedGemm<Shape<_128, _16>, Shape<_2, _1, _1>>;

using Kernel_256x16_1x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _16>, Shape<_1, _1, _1>>;
using Kernel_256x16_2x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _16>, Shape<_2, _1, _1>>;

using Kernel_256x32_1x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _32>, Shape<_1, _1, _1>>;
using Kernel_256x32_2x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _32>, Shape<_2, _1, _1>>;

using Kernel_256x64_1x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _64>, Shape<_1, _1, _1>>;
using Kernel_256x64_2x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _64>, Shape<_2, _1, _1>>;

using Kernel_256x128_1x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _128>, Shape<_1, _1, _1>>;
using Kernel_256x128_2x1x1_Coop =
    CoopGroupedGemm<Shape<_256, _128>, Shape<_2, _1, _1>>;

using Kernel_128x256_2x1x1_Coop =
    CoopGroupedGemm<Shape<_128, _256>, Shape<_2, _1, _1>>;
|
||||||
|
|
||||||
|
// Runs the grouped GEMM with the kernel configuration named by `schedule`.
//
// `schedule` must be the exact name of one of the Kernel_*_Coop aliases
// listed below; every other argument is forwarded unchanged to that
// configuration's grouped_mm. An unrecognised name raises through TORCH_CHECK.
void mm_dispatch(
    torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
    const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
    const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
    const int64_t b_group_size, const torch::Tensor& expert_offsets,
    const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
    const torch::Tensor& b_strides, const torch::Tensor& c_strides,
    const torch::Tensor& group_scale_strides, const std::string& schedule) {
// Compares `schedule` with the stringified alias name and, on a match, runs
// that configuration and returns.
#define W4A8_MOE_TRY_SCHEDULE(KERNEL)                                         \
  if (schedule == #KERNEL) {                                                  \
    KERNEL::grouped_mm(out_tensors, a_tensors, b_tensors, a_scales, b_scales, \
                       b_group_scales, b_group_size, expert_offsets,          \
                       problem_sizes, a_strides, b_strides, c_strides,        \
                       group_scale_strides);                                  \
    return;                                                                   \
  }

  W4A8_MOE_TRY_SCHEDULE(Kernel_128x16_1x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_128x16_2x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x16_1x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x16_2x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x32_1x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x32_2x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x64_1x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x64_2x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x128_1x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_256x128_2x1x1_Coop)
  W4A8_MOE_TRY_SCHEDULE(Kernel_128x256_2x1x1_Coop)

#undef W4A8_MOE_TRY_SCHEDULE

  // No configuration matched.
  TORCH_CHECK(false,
              "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
|
||||||
|
|
||||||
|
// Entry point of the W4A8 grouped (MoE) GEMM.
//
// When `maybe_schedule` holds a value it is used verbatim as the kernel
// configuration name. Otherwise a configuration is picked from the average
// number of rows per expert, assuming the rows of `a_tensors` are spread
// uniformly over the experts (the expert count is `b_tensors.size(0)`).
// All other arguments are forwarded unchanged to mm_dispatch.
//
// Raises through TORCH_CHECK when the heuristic is needed but `b_tensors` has
// no experts, or when mm_dispatch rejects the schedule name.
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
        const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
        const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
        const int64_t b_group_size, const torch::Tensor& expert_offsets,
        const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
        const torch::Tensor& b_strides, const torch::Tensor& c_strides,
        const torch::Tensor& group_scale_strides,
        std::optional<std::string> maybe_schedule) {
  // user has specified a schedule
  if (maybe_schedule) {
    mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                b_group_scales, b_group_size, expert_offsets, problem_sizes,
                a_strides, b_strides, c_strides, group_scale_strides,
                *maybe_schedule);
    return;
  }

  // use heuristic
  const int64_t m_full = a_tensors.size(0);
  const int64_t num_experts = b_tensors.size(0);
  // the per-expert batch size below divides by the expert count
  TORCH_CHECK(num_experts > 0,
              "cutlass_w4a8_moe_mm: b_tensors must contain at least one "
              "expert to select a schedule");
  // per-expert batch size assuming uniform distribution
  const int64_t m_expert = m_full / num_experts;

  std::string schedule;
  if (m_expert <= 16) {
    schedule = "Kernel_128x16_2x1x1_Coop";
  } else if (m_expert <= 32) {
    schedule = "Kernel_256x32_1x1x1_Coop";
  } else if (m_expert <= 64) {
    schedule = "Kernel_256x64_1x1x1_Coop";
  } else if (m_expert <= 128) {
    schedule = "Kernel_256x128_2x1x1_Coop";
  } else {  // m_expert > 128
    schedule = "Kernel_128x256_2x1x1_Coop";
  }

  mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
              b_group_scales, b_group_size, expert_offsets, problem_sizes,
              a_strides, b_strides, c_strides, group_scale_strides, schedule);
}
|
||||||
|
|
||||||
|
// Encodes and reorders packed int4 expert weights for the grouped GEMM.
//
// Parameters:
//   b_tensors  contiguous 3-D CUDA int32 tensor (experts, n, k / 8); every
//              int32 holds 8 int4 values. The logical k (size(2) * 8) must be
//              divisible by 256 and n by 16.
//
// Returns a tuple of:
//   * a tensor shaped like `b_tensors` holding the encoded weights, reordered
//     expert by expert into LayoutB_Reordered;
//   * an int32 tensor of shape (experts, sizeof(LayoutB_Reordered) / 4) on the
//     device of `b_tensors`, each row holding a byte copy of the reordered
//     layout object (the same layout is used for all experts).
//
// Raises through TORCH_CHECK when a precondition fails or the encoding step
// reports failure.
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
    torch::Tensor const& b_tensors) {
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32,
              "b_tensors must be int32 (8 packed int4 values per element)");
  TORCH_CHECK(b_tensors.dim() == 3,
              "b_tensors must be 3-D: (experts, n, k)");  // (experts, n, k)
  TORCH_CHECK(b_tensors.is_contiguous(), "b_tensors must be contiguous");
  TORCH_CHECK(b_tensors.is_cuda(), "b_tensors must be a CUDA tensor");

  // The encode and reorder steps below work on raw device pointers; make the
  // tensor's device current so they run where the memory lives.
  const at::cuda::OptionalCUDAGuard device_guard(b_tensors.device());

  int n = static_cast<int>(b_tensors.size(1));
  int k = static_cast<int>(b_tensors.size(2)) * PackFactor;  // logical k

  // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
  // These misalignments cause silent OOB unless run under Compute Sanitizer.
  TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
  TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");

  // we will store the layout to an int32 tensor;
  // this is the number of elements we need per layout
  constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);

  torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
  int num_experts = static_cast<int>(b_tensors.size(0));

  auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
  auto b_packed_ptr = static_cast<QuantType*>(b_tensors_packed.data_ptr());

  // multiply by ull so result does not overflow int32
  size_t num_int4_elems = 1ull * num_experts * n * k;
  bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
                                                           num_int4_elems);
  TORCH_CHECK(ok, "unified_encode_int4b failed");

  // construct the layout once; assumes each expert has the same layout
  using LayoutType = LayoutB_Reordered;
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}});
  auto shape_B = cute::make_shape(n, k, Int<1>{});
  auto layout_B = make_layout(shape_B, stride_B);
  LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);

  // reorder weights for each expert
  for (int i = 0; i < num_experts; i++) {
    // since the storage type of int4b is 1 byte but one element is 4 bits
    // we need to adjust the offset
    int64_t offset =
        1ull * i * n * k * cutlass::sizeof_bits<QuantType>::value / 8;
    cutlass::reorder_tensor(b_packed_ptr + offset, layout_B,
                            layout_B_reordered);
  }

  // save the packed layout to torch tensor so we can re-use it
  auto cpu_opts =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
  torch::Tensor layout_cpu =
      torch::empty({num_experts, layout_width}, cpu_opts);

  // every expert row receives the same byte image of the layout object
  int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
  for (int i = 0; i < num_experts; ++i) {
    std::memcpy(layout_data + i * layout_width,  // dst (int32*)
                &layout_B_reordered,             // src (LayoutType*)
                sizeof(LayoutType));             // number of bytes
  }

  torch::Tensor packed_layout =
      layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);

  return {b_tensors_packed, packed_layout};
}
|
||||||
|
|
||||||
|
// Register the CUDA-dispatch implementations for the grouped W4A8 ops:
// `mm` is bound to "cutlass_w4a8_moe_mm" and `encode_and_reorder_int4b` is
// bound to "cutlass_encode_and_reorder_int4b_grouped".
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("cutlass_w4a8_moe_mm", &mm);
  m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
}
|
||||||
|
|
||||||
|
} // namespace vllm::cutlass_w4a8_moe
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -7,6 +7,7 @@
|
|||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
#include "cutlass_extensions/torch_utils.hpp"
|
#include "cutlass_extensions/torch_utils.hpp"
|
||||||
|
#include "w4a8_utils.cuh"
|
||||||
|
|
||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
|
|
||||||
@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
|||||||
return packed_scales;
|
return packed_scales;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
GPU-accelerated implementation of cutlass::unified_encode_int4b.
|
|
||||||
Constructs a lookup table in constant memory to map 8 bits
|
|
||||||
(two 4-bit values) at a time. Assumes memory is contiguous
|
|
||||||
and pointers are 16-byte aligned.
|
|
||||||
*/
|
|
||||||
__constant__ uint8_t kNibbleLUT[256];
|
|
||||||
|
|
||||||
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
|
|
||||||
size_t nbytes) {
|
|
||||||
constexpr size_t V = sizeof(uint4); // 16 bytes
|
|
||||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
|
|
||||||
const size_t nvec = nbytes / V;
|
|
||||||
|
|
||||||
// 1-D grid-stride loop over 16-byte chunks
|
|
||||||
for (size_t vec = tid; vec < nvec; vec += nthreads) {
|
|
||||||
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
|
|
||||||
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
|
|
||||||
reinterpret_cast<uint4*>(out)[vec] = v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool upload_lut() {
|
|
||||||
std::array<uint8_t, 256> lut{};
|
|
||||||
auto map_nib = [](uint8_t v) -> uint8_t {
|
|
||||||
// 1..7 -> (8 - v); keep 0 and 8..15
|
|
||||||
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
|
|
||||||
};
|
|
||||||
for (int b = 0; b < 256; ++b) {
|
|
||||||
uint8_t lo = b & 0xF;
|
|
||||||
uint8_t hi = (b >> 4) & 0xF;
|
|
||||||
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
|
|
||||||
}
|
|
||||||
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
|
|
||||||
/*offset=*/0, cudaMemcpyHostToDevice);
|
|
||||||
|
|
||||||
return (e == cudaSuccess);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool unified_encode_int4b(cutlass::int4b_t const* in,
|
|
||||||
cutlass::int4b_t* out, size_t num_int4_elems) {
|
|
||||||
// Build/upload LUT
|
|
||||||
if (!upload_lut()) return false;
|
|
||||||
|
|
||||||
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
|
|
||||||
"int4 storage must be 1 byte");
|
|
||||||
const size_t nbytes = num_int4_elems >> 1;
|
|
||||||
|
|
||||||
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
|
|
||||||
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
|
|
||||||
|
|
||||||
// kernel launch params
|
|
||||||
constexpr int block = 256;
|
|
||||||
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
|
|
||||||
int grid = int((nvec + block - 1) / block);
|
|
||||||
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
|
|
||||||
|
|
||||||
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
|
|
||||||
cudaError_t err = cudaGetLastError();
|
|
||||||
return (err == cudaSuccess);
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
TORCH_CHECK(B.dtype() == torch::kInt32);
|
||||||
TORCH_CHECK(B.dim() == 2);
|
TORCH_CHECK(B.dim() == 2);
|
||||||
@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
|||||||
LayoutB_Reordered layout_B_reordered =
|
LayoutB_Reordered layout_B_reordered =
|
||||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||||
|
|
||||||
bool ok =
|
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
|
||||||
vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
|
n * k);
|
||||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||||
|
|
||||||
|
|||||||
90
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
Normal file
90
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
#include "w4a8_utils.cuh"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
namespace vllm::cutlass_w4a8_utils {
|
||||||
|
|
||||||
|
/*
|
||||||
|
GPU-accelerated implementation of cutlass::unified_encode_int4b.
|
||||||
|
Constructs a lookup table in constant memory to map 8 bits
|
||||||
|
(two 4-bit values) at a time. Assumes memory is contiguous
|
||||||
|
and pointers are 16-byte aligned.
|
||||||
|
*/
|
||||||
|
__constant__ uint8_t kNibbleLUT[256];
|
||||||
|
|
||||||
|
/*
  Remaps `nbytes` bytes of packed int4 data (two 4-bit values per byte)
  through kNibbleLUT, reading from `in` and writing to `out`.

  Launch layout: 1-D grid of 1-D blocks; any grid/block size is valid because
  both loops below stride by the total number of launched threads.

  Preconditions: kNibbleLUT has been populated, and `in`/`out` are aligned for
  16-byte (uint4) accesses. `in == out` is safe since each byte is read and
  written by the same thread.
*/
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
                                            size_t nbytes) {
  constexpr size_t V = sizeof(uint4);  // 16 bytes
  // Widen before multiplying so the flat index cannot wrap in 32 bits.
  const size_t tid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t nthreads = size_t(gridDim.x) * blockDim.x;
  const size_t nvec = nbytes / V;

  // Main path: 1-D grid-stride loop over whole 16-byte chunks.
  for (size_t vec = tid; vec < nvec; vec += nthreads) {
    uint4 v = reinterpret_cast<const uint4*>(in)[vec];
    uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
    for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
    reinterpret_cast<uint4*>(out)[vec] = v;
  }

  // Tail: the (nbytes % 16) bytes after the last whole chunk. Without this
  // loop they would be left unencoded in `out`.
  for (size_t i = nvec * V + tid; i < nbytes; i += nthreads) {
    out[i] = kNibbleLUT[in[i]];
  }
}
|
||||||
|
|
||||||
|
// Builds the 256-entry byte remap table on the host and copies it into the
// kNibbleLUT constant-memory symbol. Every byte is treated as two independent
// 4-bit values that are remapped separately and then recombined.
// Returns true iff cudaMemcpyToSymbol reports cudaSuccess.
static bool upload_lut() {
  // Per-nibble rule: 1..7 -> (8 - v); 0 and 8..15 are left unchanged.
  auto remap_nibble = [](uint8_t nib) -> uint8_t {
    const bool keep = (nib == 0) || ((nib & 0x8) != 0);
    return keep ? nib : uint8_t(8 - nib);
  };

  std::array<uint8_t, 256> table{};
  for (int byte = 0; byte < 256; ++byte) {
    const uint8_t low = byte & 0xF;
    const uint8_t high = (byte >> 4) & 0xF;
    table[byte] = uint8_t((remap_nibble(high) << 4) | remap_nibble(low));
  }

  const cudaError_t status =
      cudaMemcpyToSymbol(kNibbleLUT, table.data(), table.size(),
                         /*offset=*/0, cudaMemcpyHostToDevice);
  return status == cudaSuccess;
}
|
||||||
|
|
||||||
|
// Host entry point: uploads the nibble lookup table, launches
// unified_encode_int4b_device over the packed int4 buffer, and waits for it
// to finish.
//
//   in / out        device pointers to packed int4 data (two values per byte)
//   num_int4_elems  number of 4-bit elements; the byte count is half of this
//
// The launch passes no stream argument, and the function blocks on
// cudaDeviceSynchronize() before returning. Returns false (after printing a
// diagnostic) if the table upload, the launch, or the execution fails.
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
                          size_t num_int4_elems) {
  static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
                "int4 storage must be 1 byte");

  // The lookup table has to be in constant memory before the kernel runs.
  if (!upload_lut()) {
    return false;
  }

  // Two int4 values are packed into every byte.
  const size_t nbytes = num_int4_elems >> 1;
  const uint8_t* src = reinterpret_cast<uint8_t const*>(in);
  uint8_t* dst = reinterpret_cast<uint8_t*>(out);

  // One thread per 16-byte vector, rounded up to whole blocks. At least one
  // block is always launched so the configuration stays valid for tiny inputs.
  constexpr int kThreadsPerBlock = 256;
  const size_t num_vectors = nbytes / sizeof(uint4);
  const int blocks_needed =
      int((num_vectors + kThreadsPerBlock - 1) / kThreadsPerBlock);
  const int num_blocks = (blocks_needed == 0) ? 1 : blocks_needed;

  unified_encode_int4b_device<<<num_blocks, kThreadsPerBlock>>>(src, dst,
                                                                nbytes);

  // Errors in the launch configuration surface immediately.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    printf("unified_encode_int4b_device launch error: %s (%d)\n",
           cudaGetErrorString(launch_status), launch_status);
    return false;
  }

  // Errors raised while the kernel executes surface at the next sync point.
  const cudaError_t run_status = cudaDeviceSynchronize();
  if (run_status != cudaSuccess) {
    printf("unified_encode_int4b_device runtime error: %s (%d)\n",
           cudaGetErrorString(run_status), run_status);
    return false;
  }

  return true;
}
|
||||||
|
|
||||||
|
} // namespace vllm::cutlass_w4a8_utils
|
||||||
11
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
Normal file
11
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include "cutlass/numeric_types.h"
|
||||||
|
|
||||||
|
namespace vllm::cutlass_w4a8_utils {
|
||||||
|
|
||||||
|
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
|
||||||
|
size_t num_int4_elems);
|
||||||
|
|
||||||
|
} // namespace vllm::cutlass_w4a8_utils
|
||||||
@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
|||||||
void get_cutlass_moe_mm_problem_sizes_caller(
|
void get_cutlass_moe_mm_problem_sizes_caller(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||||
|
std::optional<bool> force_swap_ab = std::nullopt) {
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||||
auto options_int32 =
|
auto options_int32 =
|
||||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||||
|
|
||||||
// Swap-AB should be disabled for FP4 path
|
// Swap-AB should be disabled for FP4 path
|
||||||
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
|
bool may_swap_ab =
|
||||||
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
|
force_swap_ab.value_or((!blockscale_offsets.has_value()) &&
|
||||||
|
(topk_ids.numel() <= SWAP_AB_THRESHOLD));
|
||||||
|
|
||||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||||
atomic_buffer, num_experts, n, k, stream,
|
atomic_buffer, num_experts, n, k, stream,
|
||||||
|
|||||||
@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller(
|
|||||||
void get_cutlass_moe_mm_problem_sizes_caller(
|
void get_cutlass_moe_mm_problem_sizes_caller(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
|
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||||
|
std::optional<bool> force_swap_ab = std::nullopt);
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes1,
|
torch::Tensor& problem_sizes1,
|
||||||
@ -303,14 +304,15 @@ void get_cutlass_moe_mm_data(
|
|||||||
void get_cutlass_moe_mm_problem_sizes(
|
void get_cutlass_moe_mm_problem_sizes(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||||
|
std::optional<bool> force_swap_ab = std::nullopt) {
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
||||||
problem_sizes2, num_experts, n, k,
|
problem_sizes2, num_experts, n, k,
|
||||||
blockscale_offsets);
|
blockscale_offsets, force_swap_ab);
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
|||||||
@ -350,6 +350,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||||
// conditionally compiled so impl registration is in source file
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
|
// CUTLASS w4a8 grouped GEMM
|
||||||
|
ops.def(
|
||||||
|
"cutlass_w4a8_moe_mm("
|
||||||
|
" Tensor! out_tensors,"
|
||||||
|
" Tensor a_tensors,"
|
||||||
|
" Tensor b_tensors,"
|
||||||
|
" Tensor a_scales,"
|
||||||
|
" Tensor b_scales,"
|
||||||
|
" Tensor b_group_scales,"
|
||||||
|
" int b_group_size,"
|
||||||
|
" Tensor expert_offsets,"
|
||||||
|
" Tensor problem_sizes,"
|
||||||
|
" Tensor a_strides,"
|
||||||
|
" Tensor b_strides,"
|
||||||
|
" Tensor c_strides,"
|
||||||
|
" Tensor group_scale_strides,"
|
||||||
|
" str? maybe_schedule"
|
||||||
|
") -> ()");
|
||||||
|
ops.def(
|
||||||
|
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
|
||||||
|
"Tensor)");
|
||||||
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Dequantization for GGML.
|
// Dequantization for GGML.
|
||||||
@ -466,7 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor! problem_sizes1, "
|
" Tensor! problem_sizes1, "
|
||||||
" Tensor! problem_sizes2, "
|
" Tensor! problem_sizes2, "
|
||||||
" int num_experts, int n, int k, "
|
" int num_experts, int n, int k, "
|
||||||
" Tensor? blockscale_offsets) -> ()");
|
" Tensor? blockscale_offsets, "
|
||||||
|
" bool? force_swap_ab) -> ()");
|
||||||
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
|
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
|
||||||
&get_cutlass_moe_mm_problem_sizes);
|
&get_cutlass_moe_mm_problem_sizes);
|
||||||
|
|
||||||
|
|||||||
@ -12,8 +12,11 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
convert_packed_uint4b8_to_signed_int4_inplace,
|
||||||
|
pack_cols,
|
||||||
pack_rows,
|
pack_rows,
|
||||||
quantize_weights,
|
quantize_weights,
|
||||||
|
unpack_quantized_values_into_int32,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
@ -167,8 +170,7 @@ def create_test_tensors(
|
|||||||
|
|
||||||
# for the practical use case we need per-tok scales for fp8 activations
|
# for the practical use case we need per-tok scales for fp8 activations
|
||||||
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
|
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
|
||||||
# weights are already per-group quantized, use placeholder here
|
w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type)
|
||||||
w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)
|
|
||||||
|
|
||||||
return Tensors(
|
return Tensors(
|
||||||
w_ref=w_ref,
|
w_ref=w_ref,
|
||||||
@ -211,7 +213,7 @@ def mm_test_helper(
|
|||||||
print(output_ref)
|
print(output_ref)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
|
output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -257,7 +259,7 @@ def test_w4a8_cuda_graph():
|
|||||||
)
|
)
|
||||||
|
|
||||||
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
|
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
|
||||||
w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)
|
w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32)
|
||||||
|
|
||||||
# Construct a trivial model with a single layer that calls the kernel
|
# Construct a trivial model with a single layer that calls the kernel
|
||||||
model = W4A8Layer(
|
model = W4A8Layer(
|
||||||
@ -287,4 +289,38 @@ def test_w4a8_cuda_graph():
|
|||||||
output.zero_()
|
output.zero_()
|
||||||
g.replay()
|
g.replay()
|
||||||
|
|
||||||
torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES)
def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
    """
    W4A16 checkpoints store weights as int4b8 values packed into int32,
    whereas the CUTLASS kernels consume signed int4 packed into int32.
    This test checks that the runtime int4b8 -> signed int4 conversion
    produces exactly the same bits as the offline conversion path
    (unpack, subtract 8, repack).
    """
    _, N, K = shape

    # Random bit patterns covering the whole int32 range; 8 nibbles per word.
    int32_range = torch.iinfo(torch.int32)
    packed = torch.randint(
        int32_range.min,
        int32_range.max + 1,
        (N, K // 8),
        dtype=torch.int32,
        device="cuda",
    )

    # Reference: unpack as uint4b8, shift to signed int4, repack along columns.
    signed_vals = (
        unpack_quantized_values_into_int32(
            packed.clone(), scalar_types.uint4b8, packed_dim=1
        )
        - 8
    )
    expected = pack_cols(signed_vals & 0x0F, 4, *signed_vals.shape)

    # The conversion gets its own clone so `packed` keeps the original bits.
    actual = convert_packed_uint4b8_to_signed_int4_inplace(packed.clone())

    assert torch.equal(expected, actual)
    assert not torch.equal(expected, packed)
|
||||||
|
|||||||
340
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
340
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
pack_rows,
|
||||||
|
quantize_weights,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
|
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_quantize(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """
    Quantize one expert's weights to W4 and build the matching reference
    (dequantized) weights.

    Weight encoding/reordering and scale packing are not done here; they
    happen later, once the weights of all experts have been combined.

    Returns:
        (w_ref, w_q, w_s, w_zp): reference weights cast to `atype`, the
        quantized weights packed and transposed to a contiguous tensor,
        the group scales cast to `atype`, and the zero points returned by
        `quantize_weights`.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    _, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # The scales are later cast to fp8, so the reference is rebuilt from
    # scales that have already been rounded through `atype`.
    rounded_scales = w_s.to(atype).to(torch.float32)
    expanded_scales = rounded_scales.repeat_interleave(group_size, dim=0)
    w_ref = (w_q.to(torch.float32) * expanded_scales).to(atype)

    # Mask to the low 4 bits so int4 sign extension does not leak into packing.
    packed = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
    # Make weights row-major (N, K).
    w_q_row_major = packed.t().contiguous()

    return w_ref, w_q_row_major, w_s.to(atype), w_zp
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_preprocess(
    w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
):
    """
    Stack the per-expert tensors, then encode/reorder the weights and pack
    the scales.

    Returns:
        w_q_packed: Packed/encoded int4 weights for all experts.
        w_s_packed: Packed fp8 scales for all experts.
        packed_layout: Layout/stride metadata for grouped GEMM.
    """
    stacked_scales = torch.stack(w_s_experts)
    stacked_weights = torch.stack(w_q_experts)  # grouped op expects dim 3

    w_s_packed = ops.cutlass_pack_scale_fp8(stacked_scales)
    w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        stacked_weights
    )
    return w_q_packed, w_s_packed, packed_layout
|
||||||
|
|
||||||
|
|
||||||
|
GROUP_SIZE = 128
|
||||||
|
# (num_experts, N, K)
|
||||||
|
TEST_SHAPES = [
|
||||||
|
(8, 512, 2048),
|
||||||
|
(8, 2048, 2048),
|
||||||
|
(64, 512, 1024),
|
||||||
|
(64, 2048, 2048),
|
||||||
|
(4, 2048, 768),
|
||||||
|
(8, 768, 2048),
|
||||||
|
(64, 1536, 2048),
|
||||||
|
(128, 8192, 4096), # test overflow int32
|
||||||
|
]
|
||||||
|
ALIGNMENT = 16 # torch._scaled_mm alignment for M, needed for reference check
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MoETestSetup:
    """Bundle of sizes and tensors built by `make_moe_test_setup` for one
    grouped W4A8 GEMM test case."""

    # Problem dimensions.
    num_experts: int
    K: int
    N: int
    # Token count per expert, and their sum.
    Ms: list[int]
    M_full: int
    # Activations: `a` is fp8, `a_ref` is the same data in float32.
    a: torch.Tensor
    a_ref: torch.Tensor
    # One int64 entry per expert, filled with K.
    a_strides: torch.Tensor
    # Output buffer (M_full, N) in bfloat16.
    out: torch.Tensor
    # One int64 entry per expert, filled with N.
    c_strides: torch.Tensor
    # float32 scales: (M_full, 1) per token, (num_experts, N, 1) per channel.
    per_tok_scales: torch.Tensor
    per_chan_scales: torch.Tensor
    # Per-expert reference weights returned by `cutlass_quantize`.
    w_refs: list[torch.Tensor]
    # Outputs of `cutlass_preprocess` for the stacked experts.
    w_q_packed: torch.Tensor
    w_s_packed: torch.Tensor
    # int32 rows of [N, M, K], one per expert.
    problem_sizes: torch.Tensor
    # Starting row of each expert's tokens (exclusive prefix sum of Ms).
    expert_offsets: torch.Tensor
    # The packed layout tensor returned by `cutlass_preprocess`.
    b_strides: torch.Tensor
    # (num_experts, 2) int64; column 0 is N, column 1 is 0.
    group_scale_strides: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def make_moe_test_setup(
    num_experts: int,
    K: int,
    N: int,
    *,
    alignment: int = ALIGNMENT,
    max_blocks: int = 64,
    device: str = "cuda",
    random_zero: bool = False,
) -> MoETestSetup:
    """
    Build every tensor needed to call cutlass_w4a8_moe_mm and to check it
    against a reference.

    Each expert gets a random token count that is a multiple of `alignment`
    (between 1 and `max_blocks` multiples). With `random_zero` set and more
    than one expert, max(1, num_experts // 8) randomly chosen experts are
    given zero tokens instead.
    """
    assert K % GROUP_SIZE == 0

    # Token counts per expert (multiples of `alignment`).
    Ms = []
    for _ in range(num_experts):
        Ms.append(alignment * random.randint(1, max_blocks))

    # Optionally empty out a few randomly chosen experts.
    if random_zero and num_experts > 1:
        n_empty = max(1, num_experts // 8)
        for expert_idx in random.sample(range(num_experts), k=n_empty):
            Ms[expert_idx] = 0

    M_full = sum(Ms)
    assert M_full > 0

    def per_expert_i64(value: int) -> torch.Tensor:
        # One int64 entry per expert, all set to `value`.
        return torch.full((num_experts,), value, dtype=torch.int64, device=device)

    # Activations (fp8) plus a float32 copy for the reference computation.
    a = to_fp8(torch.randn((M_full, K), device=device))
    a_ref = a.to(torch.float32)
    a_strides = per_expert_i64(K)

    # Output buffer.
    out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
    c_strides = per_expert_i64(N)

    # Channel/token scales.
    per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
    per_chan_scales = torch.randn(
        (num_experts, N, 1), dtype=torch.float32, device=device
    )

    # Expert weights and scales, quantized one expert at a time.
    wtype = scalar_types.int4
    atype = stype = torch.float8_e4m3fn
    w_refs: list[torch.Tensor] = []
    w_qs: list[torch.Tensor] = []
    w_ss: list[torch.Tensor] = []
    for _ in range(num_experts):
        dense = to_fp8(torch.randn((K, N), device=device)).to(torch.float16)
        w_ref, w_q, w_s, _ = cutlass_quantize(
            atype, dense, wtype, stype, GROUP_SIZE, zero_points=False
        )
        w_refs.append(w_ref)
        w_qs.append(w_q)
        w_ss.append(w_s)

    w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)

    # One [N, M, K] row per expert.
    problem_sizes = torch.tensor(
        [[N, tokens, K] for tokens in Ms], dtype=torch.int32, device=device
    )

    # Exclusive prefix sum of the token counts: first row of each expert.
    expert_offsets = (
        torch.tensor([0] + Ms[:-1], dtype=torch.int64).cumsum(dim=0).to(device=device)
    )

    # B strides are the packed layout; group scale strides are (N, 0) per expert.
    group_scale_strides = torch.zeros(
        (num_experts, 2), dtype=torch.int64, device=device
    )
    group_scale_strides[:, 0] = N

    return MoETestSetup(
        num_experts=num_experts,
        K=K,
        N=N,
        Ms=Ms,
        M_full=M_full,
        a=a,
        a_ref=a_ref,
        a_strides=a_strides,
        out=out,
        c_strides=c_strides,
        per_tok_scales=per_tok_scales,
        per_chan_scales=per_chan_scales,
        w_refs=w_refs,
        w_q_packed=w_q_packed,
        w_s_packed=w_s_packed,
        problem_sizes=problem_sizes,
        expert_offsets=expert_offsets,
        b_strides=packed_layout,
        group_scale_strides=group_scale_strides,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
    """
    Build the reference output by running torch._scaled_mm once per expert
    on that expert's slice of rows. Experts with no tokens are skipped.
    """
    out_ref = torch.empty_like(setup.out)

    row_starts = setup.expert_offsets.cpu().tolist()
    row_ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()

    for expert in range(setup.num_experts):
        lo, hi = row_starts[expert], row_ends[expert]
        if lo != hi:
            rows = slice(lo, hi)
            out_ref[rows] = torch._scaled_mm(
                setup.a_ref[rows].to(torch.float8_e4m3fn),
                setup.w_refs[expert].to(torch.float8_e4m3fn).t().contiguous().t(),
                setup.per_tok_scales[rows],  # (M, 1)
                setup.per_chan_scales[expert].reshape(1, -1),  # (1, N)
                out_dtype=torch.bfloat16,
                use_fast_accum=True,
            )

    return out_ref
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("random_zero", [True, False])
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
    """Run the grouped W4A8 GEMM on a seeded random setup and compare the
    result with the per-expert torch._scaled_mm reference."""
    num_experts, N, K = shape
    current_platform.seed_everything(42)
    setup = make_moe_test_setup(
        num_experts, K, N, max_blocks=64, random_zero=random_zero
    )

    kernel_args = (
        setup.out,
        setup.a,
        setup.w_q_packed,
        setup.per_tok_scales,
        setup.per_chan_scales,
        setup.w_s_packed,
        GROUP_SIZE,
        setup.expert_offsets,
        setup.problem_sizes,
        setup.a_strides,
        setup.b_strides,
        setup.c_strides,
        setup.group_scale_strides,
    )
    ops.cutlass_w4a8_moe_mm(*kernel_args)
    torch.cuda.synchronize()

    expected = compute_moe_reference_output(setup)
    torch.testing.assert_close(setup.out, expected, rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class W4A8MoELayer(torch.nn.Module):
    """
    Thin nn.Module around cutlass_w4a8_moe_mm, used to exercise the kernel
    under CUDA graph capture.
    """

    def __init__(self, setup: MoETestSetup):
        super().__init__()
        self.setup = setup

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Run the grouped GEMM on `a`, writing into (and returning) the
        setup's preallocated output buffer."""
        cfg = self.setup
        result = cfg.out
        ops.cutlass_w4a8_moe_mm(
            result,
            a,
            cfg.w_q_packed,
            cfg.per_tok_scales,
            cfg.per_chan_scales,
            cfg.w_s_packed,
            GROUP_SIZE,
            cfg.expert_offsets,
            cfg.problem_sizes,
            cfg.a_strides,
            cfg.b_strides,
            cfg.c_strides,
            cfg.group_scale_strides,
        )
        return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
def test_cutlass_w4a8_moe_mm_cuda_graph():
    """Capture the grouped W4A8 GEMM in a CUDA graph, replay it, and compare
    the replayed output with the per-expert reference."""
    current_platform.seed_everything(42)

    # Fixed config for CUDA graph test (single parameter point).
    setup = make_moe_test_setup(num_experts=8, K=512, N=2048, max_blocks=32)

    # Module that calls the grouped GEMM kernel.
    layer = W4A8MoELayer(setup)

    # Reference output, computed once up front.
    expected = compute_moe_reference_output(setup)

    # Static input tensor reused by every graph replay.
    static_input = setup.a.clone()

    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = layer(static_input)

    # Clear the buffer so the comparison can only pass on replayed results.
    static_output.zero_()
    graph.replay()

    torch.testing.assert_close(static_output, expected, rtol=1e-2, atol=1e-2)
|
||||||
@ -695,6 +695,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
|||||||
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
|
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
|
||||||
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
@register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
def cutlass_encode_and_reorder_int4b_grouped_fake(
    b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of the grouped int4 encode/reorder op.

    The Python wrapper for this op is annotated as returning a pair and its
    caller unpacks two values (reordered weights, packed layout), so the fake
    must return two tensors as well; returning a single tensor breaks
    tracing at the unpack site.

    Args:
        b: Packed weight tensor whose leading dimension indexes experts.

    Returns:
        A tuple ``(reordered, layout)`` of uninitialised tensors:
        ``reordered`` has the shape/dtype of ``b`` in contiguous format and
        ``layout`` is int32 with shape ``[b.size(0), 3]``, matching how
        ``b_strides1/2`` are documented in ``cutlass_moe_w4a8_fp8``.
    """
    reordered = torch.empty_like(b, memory_format=torch.contiguous_format)
    # NOTE(review): the layout shape follows the "[num_experts, 3], int32"
    # description of b_strides — confirm against the C++ op if it changes.
    layout = torch.empty((b.size(0), 3), dtype=torch.int32, device=b.device)
    return reordered, layout
|
||||||
|
|
||||||
|
|
||||||
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
|
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
|
||||||
|
|
||||||
@ -1058,6 +1062,7 @@ def get_cutlass_moe_mm_problem_sizes(
|
|||||||
n: int,
|
n: int,
|
||||||
k: int,
|
k: int,
|
||||||
blockscale_offsets: torch.Tensor | None = None,
|
blockscale_offsets: torch.Tensor | None = None,
|
||||||
|
force_swap_ab: bool | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compute only the per-expert problem sizes needed by the two grouped matrix
|
Compute only the per-expert problem sizes needed by the two grouped matrix
|
||||||
@ -1067,9 +1072,20 @@ def get_cutlass_moe_mm_problem_sizes(
|
|||||||
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
|
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
|
||||||
multiplication for the two grouped MMs
|
multiplication for the two grouped MMs
|
||||||
used in the fused MoE operation.
|
used in the fused MoE operation.
|
||||||
|
Optional:
|
||||||
|
- force_swap_ab: If set to True or False, explicitly enable or disable the
|
||||||
|
A/B input swap optimization. If None (default), the swap
|
||||||
|
is selected automatically based on tensor sizes.
|
||||||
"""
|
"""
|
||||||
return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
|
return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
|
||||||
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets
|
topk_ids,
|
||||||
|
problem_sizes1,
|
||||||
|
problem_sizes2,
|
||||||
|
num_experts,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
blockscale_offsets,
|
||||||
|
force_swap_ab,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1457,6 +1473,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
|
|||||||
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
|
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_w4a8_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    b_group_scales: torch.Tensor,
    b_group_size: int,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    c_strides: torch.Tensor,
    group_scale_strides: torch.Tensor,
    maybe_schedule: str | None = None,
):
    """
    Thin wrapper around the ``_C.cutlass_w4a8_moe_mm`` custom op: the
    CUTLASS fused-MoE grouped matrix multiplication for the W4A8 scheme
    (INT4 weights with group-wise scales, FP8 activations), with per-channel
    and per-token scaling applied in the epilogue.

    Args:
        out_tensors:
            Output buffer shared by all experts; written in place.
        a_tensors:
            FP8 (E4M3FN) activations for all experts.
        b_tensors:
            INT4 weights for all experts, packed into INT32 words.
        a_scales:
            Per-token activation scales, applied in the epilogue.
        b_scales:
            Per-channel weight scales for each expert, applied in the
            epilogue.
        b_group_scales:
            FP8 scales for the group-wise INT4 weight blocks.
        b_group_size:
            Number of weight elements covered by one entry of
            ``b_group_scales``.
        expert_offsets:
            Cumulative token offsets.
        problem_sizes:
            Per-expert (M, N, K) GEMM sizes consumed by the grouped GEMM
            launcher.
        a_strides, b_strides, c_strides, group_scale_strides:
            Stride/layout descriptors for the corresponding tensors.
        maybe_schedule:
            Optional override selecting a specific kernel or epilogue
            schedule.

    Returns:
        Whatever the underlying op returns; the GEMM result itself is
        delivered through ``out_tensors``.
    """
    grouped_mm = torch.ops._C.cutlass_w4a8_moe_mm
    return grouped_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        b_group_scales,
        b_group_size,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        group_scale_strides,
        maybe_schedule,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_encode_and_reorder_int4b_grouped(
    b_tensors: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward ``b_tensors`` to the ``_C.cutlass_encode_and_reorder_int4b_grouped`` op.

    Returns:
        The two tensors produced by the op. The in-file caller treats them
        as (reordered weights, packed layout passed on as ``b_strides``).
    """
    encode_op = torch.ops._C.cutlass_encode_and_reorder_int4b_grouped
    return encode_op(b_tensors)
|
||||||
|
|
||||||
|
|
||||||
if hasattr(torch.ops._C, "permute_cols"):
|
if hasattr(torch.ops._C, "permute_cols"):
|
||||||
|
|
||||||
@register_fake("_C::permute_cols")
|
@register_fake("_C::permute_cols")
|
||||||
|
|||||||
@ -63,8 +63,10 @@ if HAS_TRITON:
|
|||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
CutlassBatchedExpertsFp8,
|
CutlassBatchedExpertsFp8,
|
||||||
CutlassExpertsFp8,
|
CutlassExpertsFp8,
|
||||||
|
CutlassExpertsW4A8Fp8,
|
||||||
cutlass_moe_fp4,
|
cutlass_moe_fp4,
|
||||||
cutlass_moe_fp8,
|
cutlass_moe_fp8,
|
||||||
|
cutlass_moe_w4a8_fp8,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
@ -88,8 +90,10 @@ if HAS_TRITON:
|
|||||||
"grouped_topk",
|
"grouped_topk",
|
||||||
"cutlass_moe_fp8",
|
"cutlass_moe_fp8",
|
||||||
"cutlass_moe_fp4",
|
"cutlass_moe_fp4",
|
||||||
|
"cutlass_moe_w4a8_fp8",
|
||||||
"CutlassExpertsFp8",
|
"CutlassExpertsFp8",
|
||||||
"CutlassBatchedExpertsFp8",
|
"CutlassBatchedExpertsFp8",
|
||||||
|
"CutlassExpertsW4A8Fp8",
|
||||||
"TritonExperts",
|
"TritonExperts",
|
||||||
"BatchedTritonExperts",
|
"BatchedTritonExperts",
|
||||||
"DeepGemmExperts",
|
"DeepGemmExperts",
|
||||||
|
|||||||
@ -143,6 +143,7 @@ class FusedMoEQuantDesc:
|
|||||||
scale: Union[torch.Tensor, "PrecisionConfig", None] = None
|
scale: Union[torch.Tensor, "PrecisionConfig", None] = None
|
||||||
|
|
||||||
# Quantization alphas or gscales, used for nvfp4 types.
|
# Quantization alphas or gscales, used for nvfp4 types.
|
||||||
|
# W4A8 FP8: used for per-channel scales
|
||||||
# TODO(bnell): put some of these in subclasses
|
# TODO(bnell): put some of these in subclasses
|
||||||
alpha_or_gscale: torch.Tensor | None = None
|
alpha_or_gscale: torch.Tensor | None = None
|
||||||
|
|
||||||
@ -442,7 +443,9 @@ class FusedMoEQuantConfig:
|
|||||||
- a1_scale: Optional scale to be used for a1.
|
- a1_scale: Optional scale to be used for a1.
|
||||||
- a2_scale: Optional scale to be used for a2.
|
- a2_scale: Optional scale to be used for a2.
|
||||||
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
|
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
|
||||||
|
per-channel scales for w1 (for W4A8 FP8).
|
||||||
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
|
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
|
||||||
|
per-channel scales for w2 (for W4A8 FP8).
|
||||||
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
|
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
|
||||||
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
|
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
|
||||||
- w1_bias: Optional biases for w1 (GPT OSS Triton).
|
- w1_bias: Optional biases for w1 (GPT OSS Triton).
|
||||||
@ -461,6 +464,7 @@ class FusedMoEQuantConfig:
|
|||||||
"mxfp4",
|
"mxfp4",
|
||||||
"mxfp6_e3m2",
|
"mxfp6_e3m2",
|
||||||
"mxfp6_e2m3",
|
"mxfp6_e2m3",
|
||||||
|
"int4",
|
||||||
}
|
}
|
||||||
|
|
||||||
if weight_dtype is None:
|
if weight_dtype is None:
|
||||||
@ -671,6 +675,31 @@ def int8_w8a16_moe_quant_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def int4_w4afp8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for int4 weights combined with fp8 activations.

    All arguments are forwarded unchanged to ``FusedMoEQuantConfig.make``;
    this helper only pins the activation quant dtype to
    ``torch.float8_e4m3fn`` and the weight dtype to ``"int4"``.
    """
    activation_quant_dtype = torch.float8_e4m3fn
    return FusedMoEQuantConfig.make(
        activation_quant_dtype,
        weight_dtype="int4",
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
    )
|
||||||
|
|
||||||
|
|
||||||
def biased_moe_quant_config(
|
def biased_moe_quant_config(
|
||||||
w1_bias: torch.Tensor | None,
|
w1_bias: torch.Tensor | None,
|
||||||
w2_bias: torch.Tensor | None,
|
w2_bias: torch.Tensor | None,
|
||||||
|
|||||||
@ -1052,3 +1052,404 @@ def run_cutlass_block_scaled_fused_experts(
|
|||||||
return (
|
return (
|
||||||
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
|
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||||
).sum(dim=1)
|
).sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
# W4A8
|
||||||
|
def run_cutlass_moe_w4a8_fp8(
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
    activation_callable: Callable,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    w1_chan_scale: torch.Tensor,
    w2_chan_scale: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_num_tokens: torch.Tensor | None,
    out_dtype: torch.dtype,
    per_act_token: bool,
    per_out_ch: bool,
    use_batched_format: bool,
    topk_weights: torch.Tensor | None,
    group_size: int,
):
    """Run one W4A8 fused-MoE pass and write the result into ``output``.

    Sequence: permute tokens by expert -> grouped GEMM with ``w1`` ->
    activation -> dynamic fp8 quantization -> grouped GEMM with ``w2`` ->
    unpermute (with ``topk_weights``) into ``output``.

    The intermediate buffers are all carved out of ``workspace13`` and
    ``workspace2``; several of them alias the same storage, so the order of
    the steps below matters.

    Preconditions enforced by assertions: per-token and per-channel scaling
    are both enabled, group scales are ``float8_e4m3fn``, weights are
    ``int32``, channel scales are ``float32``, ``a2_scale`` is ``None``,
    ``out_dtype`` is bfloat16, the batched format is not used,
    ``group_size`` is 128 and ``global_num_experts`` is not -1.
    """
    # NOTE(review): hidden_states is presumably already fp8-quantized, since
    # it is paired with a1q_scale — confirm against the caller.
    num_tokens = hidden_states.size(0)
    num_local_experts = w1.size(0)
    dev = hidden_states.device
    _, K, packed_n = w2.shape
    # Eight int4 values are packed into one int32, so the logical N is 8x
    # the packed extent.
    N = packed_n * 8

    assert per_act_token, "W4A8 must use per-token scales"
    assert per_out_ch, "W4A8 must use per-channel scales"
    assert w1_scale is not None
    assert w2_scale is not None
    assert w1_scale.dtype == torch.float8_e4m3fn
    assert w2_scale.dtype == torch.float8_e4m3fn
    assert w1.dtype == torch.int32
    assert w2.dtype == torch.int32
    assert w1_chan_scale.dtype == torch.float32
    assert w2_chan_scale.dtype == torch.float32
    assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
    assert a1q_scale is not None
    assert a2_scale is None
    assert out_dtype in [torch.bfloat16], f"Invalid output dtype: {out_dtype}"
    if expert_map is not None:
        assert expert_num_tokens is None
    assert not use_batched_format, "batched format not supported yet"
    assert group_size == 128, f"Only group size 128 supported but got {group_size=}"

    assert global_num_experts != -1
    assert w1.size(2) * 8 == K, (
        f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
    )

    # Express the routing in local expert ids when an expert map is given.
    if expert_map is None:
        local_topk_ids = topk_ids
    else:
        mapped_ids = expert_map[topk_ids]
        local_topk_ids = torch.where(mapped_ids != -1, mapped_ids, -1)

    topk = local_topk_ids.size(1)
    rows = num_tokens * topk

    # The workspaces are sized for the hidden-states dtype (bf16); the fp8
    # buffers are obtained by viewing them as float8_e4m3fn.
    fp8 = torch.float8_e4m3fn
    a1q_perm = _resize_cache(workspace2.view(dtype=fp8), (rows, K))
    mm1_out = _resize_cache(workspace13, (rows, N * 2))
    act_out = _resize_cache(workspace2, (rows, N))
    quant_out = _resize_cache(workspace13.view(dtype=fp8), (rows, N))
    mm2_out = _resize_cache(workspace2, (rows, K))

    sizes_shape = (global_num_experts, 3)
    problem_sizes1 = torch.empty(sizes_shape, dtype=torch.int32, device=dev)
    problem_sizes2 = torch.empty(sizes_shape, dtype=torch.int32, device=dev)

    if expert_map is not None:
        num_expert = expert_map.size(0)
    else:
        num_expert = global_num_experts

    # The permuted activations are written into the workspace2-backed buffer.
    a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
        hidden_states,
        a1q_scale,
        topk_ids,
        num_expert,
        num_local_experts,
        expert_map,
        permuted_hidden_states=a1q_perm,
    )
    # Drop the trailing entry of the offsets returned by moe_permute.
    expert_offsets = expert_offsets[:-1]

    # For RS gemm SwapAB is always enabled (swap logical M, N in the
    # problem shape).
    ops.get_cutlass_moe_mm_problem_sizes(
        local_topk_ids,
        problem_sizes1,
        problem_sizes2,
        global_num_experts,
        N,
        K,
        force_swap_ab=True,
    )

    # First grouped GEMM: permuted activations x w1.
    ops.cutlass_w4a8_moe_mm(
        mm1_out,
        a1q,
        w1,
        a1q_scale,
        w1_chan_scale,
        w1_scale,
        group_size,
        expert_offsets,
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides1,
    )

    activation_callable(act_out, mm1_out)

    # Re-quantize the activation output to fp8 for the second GEMM.
    a2q, a2q_scale = ops.scaled_fp8_quant(
        act_out,
        a2_scale,
        use_per_token_if_dynamic=per_act_token,
        output=quant_out,
    )

    # With an expert map in use the second GEMM's output buffer is cleared
    # before the kernel runs.
    if expert_map is not None:
        mm2_out.fill_(0)

    # Second grouped GEMM: quantized activations x w2.
    ops.cutlass_w4a8_moe_mm(
        mm2_out,
        a2q,
        w2,
        a2q_scale,
        w2_chan_scale,
        w2_scale,
        group_size,
        expert_offsets,
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
    )

    # for non-chunking mode the output is resized from workspace13
    # so we need to make sure mm2_out uses workspace2.
    moe_unpermute(
        out=output,
        permuted_hidden_states=mm2_out,
        topk_weights=topk_weights,
        inv_permuted_idx=inv_perm,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
    """Modular-kernel experts implementation backed by ``run_cutlass_moe_w4a8_fp8``.

    Holds the per-GEMM stride/layout tensors and the group size, and
    forwards them together with the quant-config scales on every ``apply``.
    """

    def __init__(
        self,
        out_dtype: torch.dtype | None,
        a_strides1: torch.Tensor,
        a_strides2: torch.Tensor,
        b_strides1: torch.Tensor,
        b_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        s_strides1: torch.Tensor,
        s_strides2: torch.Tensor,
        quant_config: FusedMoEQuantConfig,
        group_size: int,
    ):
        """Record the output dtype, stride tensors and group size.

        The ``*1`` tensors belong to the first grouped GEMM and the ``*2``
        tensors to the second; ``s_strides*`` describe the group-wise
        scales.
        """
        super().__init__(quant_config)
        self.out_dtype = out_dtype
        # Strides for GEMM 1.
        self.a_strides1 = a_strides1
        # Strides for GEMM 2 are stored interleaved with GEMM 1 to keep the
        # same attribute-assignment order as before.
        self.a_strides2 = a_strides2
        self.b_strides1 = b_strides1
        self.b_strides2 = b_strides2
        self.c_strides1 = c_strides1
        self.c_strides2 = c_strides2
        self.s_strides1 = s_strides1
        self.s_strides2 = s_strides2
        self.group_size = group_size

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Both input and output activations use the Standard format."""
        standard = mk.FusedMoEActivationFormat.Standard
        return (standard, standard)

    def supports_chunking(self) -> bool:
        """Chunked execution is supported."""
        return True

    def supports_expert_map(self) -> bool:
        """An expert map is supported."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op reducer.

        topk weights and reduction are fused in moe_unpermute cuda kernel.
        """
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """Use ``out_dtype`` for the workspaces when set, else ``act_dtype``."""
        if self.out_dtype is not None:
            return self.out_dtype
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the shapes of (workspace13, workspace2, output).

        Both workspaces have ``M * topk`` rows; their widths are
        ``max(N, K)`` and ``max(N // 2, K)`` respectively. The output is
        ``(M, K)``.
        """
        rows = M * topk
        shape13 = (rows, max(N, K))
        shape2 = (rows, max(N // 2, K))
        out_shape = (M, K)
        return (shape13, shape2, out_shape)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the W4A8 experts computation, writing the result to ``output``.

        Zero points are rejected, the batched activation format is
        rejected, and everything else is delegated to
        ``run_cutlass_moe_w4a8_fp8``. ``expert_tokens_meta`` and
        ``apply_router_weight_on_input`` are accepted but not used here.
        """
        assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
        assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

        # No per-expert token counts are supplied to the runner.
        expert_num_tokens = None

        def activation_callable(o, i):
            # Bind the activation name so the runner can call it as fn(out, in).
            return self.activation(activation, o, i)

        batched = mk.FusedMoEActivationFormat.BatchedExperts
        use_batched_format = self.activation_formats[0] == batched
        assert not use_batched_format, "batched format not supported"

        in_dtype = hidden_states.dtype
        # Fall back to the input dtype when no explicit output dtype was set.
        effective_out_dtype = self.out_dtype if self.out_dtype is not None else in_dtype

        run_cutlass_moe_w4a8_fp8(
            output,
            hidden_states,
            w1,
            w2,
            topk_ids,
            activation_callable,
            global_num_experts,
            expert_map,
            self.w1_scale,
            self.w2_scale,
            a1q_scale,
            a2_scale,
            self.g1_alphas,  # per-channel scales
            self.g2_alphas,  # per-channel scales
            self.a_strides1,
            self.a_strides2,
            self.b_strides1,
            self.b_strides2,
            self.c_strides1,
            self.c_strides2,
            self.s_strides1,
            self.s_strides2,
            workspace13,
            workspace2,
            expert_num_tokens,
            effective_out_dtype,
            self.per_act_token_quant,
            self.per_out_ch_quant,
            use_batched_format,
            topk_weights,
            self.group_size,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_moe_w4a8_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    activation: str = "silu",
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    group_size: int = 128,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    mixed-dtype grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights,
        packed into int32 words.
        Shape: [num_experts, 2*N, K // packed_factor]
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights,
        packed into int32 words.
        Shape: [num_experts, K, N // packed_factor]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mappings.
    - a_strides1 (torch.Tensor): The input strides for the first gemm.
        Shape: [num_experts]
    - a_strides2 (torch.Tensor): The input strides for the second gemm.
        Shape: [num_experts]
    - b_strides1 (torch.Tensor): The packed layout for the first gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - b_strides2 (torch.Tensor): The packed layout for the second gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - c_strides1 (torch.Tensor): The output strides for the first gemm.
        Shape: [num_experts]
    - c_strides2 (torch.Tensor): The output strides for the second gemm.
        Shape: [num_experts]
    - s_strides1 (torch.Tensor): strides for the group-wise scales for the first gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - s_strides2 (torch.Tensor): strides for the group-wise scales for the second gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - quant_config (FusedMoEQuantConfig): Quantization configuration handed
        to the experts implementation; must not be None. It is the source
        of the group-wise and per-channel weight scales as well as the
        per-token / per-channel quantization flags.
    - activation (str): The activation function to use.
    - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
        every Rank is responsible for a subset of experts. expert_map is a
        mapping from global expert-id to local expert-id. When expert_map[i]
        is -1, it means that this Rank is not responsible for global
        expert-id i.
    - apply_router_weight_on_input (bool): When true, the topk weights are
        applied directly on the inputs. This is only applicable when topk is 1.
    - global_num_experts (int): The total number of experts. When left at
        -1, the number of experts in w1_q (its leading dimension) is used.
    - group_size (int): The number of weights per scale factor

    Returns:
    - torch.Tensor: The bf16 output tensor after applying the MoE layer.
    """
    assert quant_config is not None

    # -1 means "not specified": fall back to the expert count of the weights.
    num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)

    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsW4A8Fp8(
            out_dtype=a.dtype,
            a_strides1=a_strides1,
            a_strides2=a_strides2,
            b_strides1=b_strides1,
            b_strides2=b_strides2,
            c_strides1=c_strides1,
            c_strides2=c_strides2,
            s_strides1=s_strides1,
            s_strides2=s_strides2,
            quant_config=quant_config,
            group_size=group_size,
        ),
    )

    return fn(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        activation=activation,
        global_num_experts=num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
|
||||||
|
|||||||
@ -256,7 +256,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
if format is not None
|
if format is not None
|
||||||
else is_activation_quantization_format(quant_format)
|
else is_activation_quantization_format(quant_format)
|
||||||
)
|
)
|
||||||
# TODO(czhu): w4a8fp8 is in packed-quantized format
|
# w4a8fp8 is in packed-quantized format
|
||||||
# but needs input activation quantization
|
# but needs input activation quantization
|
||||||
input_activations = quant_config.get("input_activations")
|
input_activations = quant_config.get("input_activations")
|
||||||
if act_quant_format or input_activations:
|
if act_quant_format or input_activations:
|
||||||
|
|||||||
@ -33,6 +33,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
fp8_w8a8_moe_quant_config,
|
fp8_w8a8_moe_quant_config,
|
||||||
int4_w4a16_moe_quant_config,
|
int4_w4a16_moe_quant_config,
|
||||||
|
int4_w4afp8_moe_quant_config,
|
||||||
int8_w8a8_moe_quant_config,
|
int8_w8a8_moe_quant_config,
|
||||||
int8_w8a16_moe_quant_config,
|
int8_w8a16_moe_quant_config,
|
||||||
nvfp4_moe_quant_config,
|
nvfp4_moe_quant_config,
|
||||||
@ -79,7 +80,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
prepare_moe_fp8_layer_for_marlin,
|
prepare_moe_fp8_layer_for_marlin,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
convert_bf16_scales_to_fp8,
|
||||||
|
convert_packed_uint4b8_to_signed_int4_inplace,
|
||||||
|
swizzle_blockscale,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
@ -204,6 +209,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
return CompressedTensorsW8A8Int8MoEMethod(
|
return CompressedTensorsW8A8Int8MoEMethod(
|
||||||
weight_quant, input_quant, layer.moe_config
|
weight_quant, input_quant, layer.moe_config
|
||||||
)
|
)
|
||||||
|
elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant):
|
||||||
|
logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod")
|
||||||
|
return CompressedTensorsW4A8Fp8MoEMethod(
|
||||||
|
weight_quant, input_quant, layer.moe_config
|
||||||
|
)
|
||||||
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||||
return CompressedTensorsW4A8Int8MoEMethod(
|
return CompressedTensorsW4A8Int8MoEMethod(
|
||||||
weight_quant, input_quant, layer.moe_config
|
weight_quant, input_quant, layer.moe_config
|
||||||
@ -2428,3 +2438,331 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
int(_act_kind(activation)),
|
int(_act_kind(activation)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
def __init__(
    self,
    weight_quant: QuantizationArgs,
    input_quant: QuantizationArgs,
    moe: FusedMoEConfig,
    layer_name: str | None = None,
):
    """Validate the W4A8 quantization arguments and set up activation quant.

    Only symmetric weights, an ``actorder`` other than ``"group"`` and a
    group size of 128 are accepted. A dynamic per-token ``QuantFP8``
    instance is created for the activations.
    """
    super().__init__(moe)
    self.weight_quant = weight_quant
    self.input_quant = input_quant

    self.group_size = weight_quant.group_size
    self.num_bits = weight_quant.num_bits
    # Number of quantized values stored in one 32-bit word.
    self.packed_factor = 32 // self.num_bits

    assert weight_quant.symmetric, (
        "Only symmetric quantization is supported for W4A8 MoE"
    )
    assert weight_quant.actorder != "group"
    assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE"

    self.disable_expert_map = False
    self.layer_name = layer_name

    # Imported here, at the point of use, exactly as the original did.
    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        GroupShape,
    )

    per_token = GroupShape.PER_TOKEN
    self.quant_fp8 = QuantFP8(static=False, group_shape=per_token)
|
||||||
|
|
||||||
|
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the packed weights, group scales and shape parameters on ``layer``.

    Parameters are registered in this order: ``w13_weight_packed``,
    ``w2_weight_packed``, ``w13_weight_scale``, ``w2_weight_scale``,
    ``w2_weight_shape``, ``w13_weight_shape``. The packed weights receive
    ``extra_weight_attrs`` as passed in; the scales and shape parameters
    receive them after ``quant_method`` has been set to the GROUP value.
    Input scales are not used and are set to ``None`` on the layer.
    """
    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    # requirement for CUTLASS reorder_tensor
    assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256"
    assert intermediate_size_per_partition % 256 == 0, (
        f"{intermediate_size_per_partition=} must be divisible by 256"
    )
    # storage type, pack 8xint4 into int32
    params_dtype = torch.int32

    def _register(name: str, data: torch.Tensor) -> None:
        # Wrap as a frozen parameter, attach it to the layer, then tag it
        # with whatever extra_weight_attrs holds at this moment.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    gate_up_rows = 2 * intermediate_size_per_partition

    # WEIGHTS
    _register(
        "w13_weight_packed",
        torch.empty(
            num_experts,
            gate_up_rows,
            hidden_size // self.packed_factor,
            dtype=params_dtype,
        ),
    )
    _register(
        "w2_weight_packed",
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // self.packed_factor,
            dtype=params_dtype,
        ),
    )

    # SCALES
    # weight_scale refers to the group-wise scales
    # they are initially loaded as bf16, we will convert to fp8
    # after loading
    # Add PER-GROUP quantization for FusedMoE.weight_loader. This must
    # happen after the packed weights were tagged and before the scales are.
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
    )
    _register(
        "w13_weight_scale",
        torch.ones(
            num_experts,
            gate_up_rows,
            hidden_size // self.group_size,
            dtype=layer.orig_dtype,
        ),
    )
    _register(
        "w2_weight_scale",
        torch.ones(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // self.group_size,
            dtype=layer.orig_dtype,
        ),
    )

    # weight shapes
    _register("w2_weight_shape", torch.empty(num_experts, 2))
    _register("w13_weight_shape", torch.empty(num_experts, 2))

    # don't use input scales
    layer.w13_input_scale = None
    layer.w2_input_scale = None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer):
    """Post-process the loaded W4A8 MoE tensors for the grouped GEMM kernel.

    Steps, in order:
      1. Build the per-expert stride tensors (A/C strides and group-scale
         strides) and store them on ``self``.
      2. Convert the packed weights to signed int4 in place, encode/reorder
         them, and replace the ``w13_weight_packed`` / ``w2_weight_packed``
         parameters; the returned layouts are kept as ``b_strides1/2``.
      3. Convert the loaded group scales into (fp8 group scales, channel
         scales), register the channel scales on ``layer``, then pack the
         fp8 group scales and replace ``w13_weight_scale`` / ``w2_weight_scale``.
    """
    target_device = layer.w13_weight_packed.device

    def _uniform_stride(value):
        # One int64 entry per local expert, every entry equal to `value`.
        return torch.full(
            (layer.local_num_experts,),
            value,
            device=target_device,
            dtype=torch.int64,
        )

    def _scale_stride(leading):
        # sizeof(StrideS) = 16 bytes, so each stride is encoded as 2 x int64.
        # Only the first slot is populated; the second one stays zero.
        stride = torch.zeros(
            (layer.local_num_experts, 2), device=target_device, dtype=torch.int64
        )
        stride[:, 0] = leading
        return stride

    # STRIDES
    # A, C
    self.a_strides1_c_strides2 = _uniform_stride(layer.hidden_size)
    self.a_strides2 = _uniform_stride(layer.intermediate_size_per_partition)
    self.c_strides1 = _uniform_stride(2 * layer.intermediate_size_per_partition)

    # S (group-wise scales)
    self.s_strides1 = _scale_stride(2 * layer.intermediate_size_per_partition)
    self.s_strides2 = _scale_stride(layer.hidden_size)

    # Encode and reorder the weight tensors, and obtain the layout to hand to
    # the grouped gemm kernel. `b_strides1/2` specifies the entire layout.
    convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed)
    w13_reordered, w13_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        layer.w13_weight_packed
    )
    self.b_strides1 = w13_layout
    replace_parameter(layer, "w13_weight_packed", w13_reordered)

    convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed)
    w2_reordered, w2_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        layer.w2_weight_packed
    )
    self.b_strides2 = w2_layout
    replace_parameter(layer, "w2_weight_packed", w2_reordered)

    # Convert bf16 scales to (fp8_scales, channel_scales).
    w13_group_fp8, w13_chan = convert_bf16_scales_to_fp8(
        self.quant_fp8, layer.w13_weight_scale
    )
    w2_group_fp8, w2_chan = convert_bf16_scales_to_fp8(
        self.quant_fp8, layer.w2_weight_scale
    )

    # Register the channel scales as non-trainable parameters.
    for chan_name, chan_scale in (
        ("w13_weight_chan_scale", w13_chan),
        ("w2_weight_chan_scale", w2_chan),
    ):
        layer.register_parameter(
            chan_name, torch.nn.Parameter(chan_scale, requires_grad=False)
        )

    # The scales are stored as (E, N, K // 128) but the kernel expects
    # (E, K // 128, N) in row-major format, so swap the last two dims and
    # make the result contiguous before packing.
    for scale_name, group_scale in (
        ("w13_weight_scale", w13_group_fp8),
        ("w2_weight_scale", w2_group_fp8),
    ):
        packed_scale = ops.cutlass_pack_scale_fp8(
            group_scale.permute(0, 2, 1).contiguous()
        )
        replace_parameter(layer, scale_name, packed_scale)
|
||||||
|
|
||||||
|
def maybe_make_prepare_finalize(
    self,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
    """Defer to the parent class, forwarding ``routing_tables`` unchanged."""
    parent = super()
    return parent.maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the quant config from the layer's group and channel scales.

    Both per-group and per-channel scales are stored. The group size is
    deliberately not specified here because the quant config logic assumes
    group-wise scaling and channel-wise scaling are mutually exclusive.
    """
    config_kwargs = {
        # group scales
        "w1_scale": layer.w13_weight_scale,
        "w2_scale": layer.w2_weight_scale,
        # channel scales
        "g1_alphas": layer.w13_weight_chan_scale,
        "g2_alphas": layer.w2_weight_chan_scale,
        # always use dynamic per-token activation quantization
        "per_act_token_quant": True,
        # always use per-channel
        "per_out_ch_quant": True,
    }
    return int4_w4afp8_moe_quant_config(**config_kwargs)
|
||||||
|
|
||||||
|
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Construct the ``CutlassExpertsW4A8Fp8`` experts object for this layer.

    Requires ``self.moe_quant_config`` to be set and the standard activation
    format. As a side effect, sets ``self.disable_expert_map`` based on the
    number of dispatchers and on whether the experts support an expert map.
    """
    assert self.moe_quant_config is not None
    uses_standard_format = (
        prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
    )
    assert uses_standard_format, "BatchedExperts not supported"

    from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8

    experts: FusedMoEPermuteExpertsUnpermute

    logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)

    # `a_strides1_c_strides2` is a single tensor shared by the first GEMM's
    # A stride and the second GEMM's C stride.
    stride_kwargs = dict(
        a_strides1=self.a_strides1_c_strides2,
        a_strides2=self.a_strides2,
        b_strides1=self.b_strides1,
        b_strides2=self.b_strides2,
        c_strides1=self.c_strides1,
        c_strides2=self.a_strides1_c_strides2,
        s_strides1=self.s_strides1,
        s_strides2=self.s_strides2,
    )
    experts = CutlassExpertsW4A8Fp8(
        out_dtype=self.moe.in_dtype,
        **stride_kwargs,
        quant_config=self.moe_quant_config,
        group_size=self.group_size,
    )

    dispatcher_count = prepare_finalize.num_dispatchers()
    # The expert map is dropped when there is more than one dispatcher or
    # when the experts implementation cannot consume one.
    self.disable_expert_map = (
        dispatcher_count > 1 or not experts.supports_expert_map()
    )

    return experts
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: int | None = None,
|
||||||
|
num_expert_group: int | None = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: torch.Tensor | None = None,
|
||||||
|
custom_routing_function: Callable | None = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
|
e_score_correction_bias: torch.Tensor | None = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
expert_load_view: torch.Tensor | None = None,
|
||||||
|
logical_to_physical_map: torch.Tensor | None = None,
|
||||||
|
logical_replica_count: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
if enable_eplb:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
|
||||||
|
)
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
topk_weights, topk_ids, _ = layer.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
cutlass_moe_w4a8_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cutlass_moe_w4a8_fp8(
|
||||||
|
x,
|
||||||
|
layer.w13_weight_packed,
|
||||||
|
layer.w2_weight_packed,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
a_strides1=self.a_strides1_c_strides2,
|
||||||
|
a_strides2=self.a_strides2,
|
||||||
|
b_strides1=self.b_strides1,
|
||||||
|
b_strides2=self.b_strides2,
|
||||||
|
c_strides1=self.c_strides1,
|
||||||
|
c_strides2=self.a_strides1_c_strides2,
|
||||||
|
s_strides1=self.s_strides1,
|
||||||
|
s_strides2=self.s_strides2,
|
||||||
|
group_size=self.group_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
def supports_eplb(self) -> bool:
    """Whether this method supports EPLB; it does not (``apply`` rejects it)."""
    eplb_supported = False
    return eplb_supported
|
||||||
|
|||||||
@ -128,14 +128,15 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(czhu): allocate the packed fp8 scales memory here?
|
# After loading, we will transform bf16 -> fp8 ->
|
||||||
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
|
# expand by 8x via `cutlass_pack_scale_fp8`
|
||||||
|
# and construct per-channel fp32 scales.
|
||||||
weight_scale_args = {
|
weight_scale_args = {
|
||||||
"weight_loader": weight_loader,
|
"weight_loader": weight_loader,
|
||||||
"data": torch.empty(
|
"data": torch.empty(
|
||||||
output_size_per_partition,
|
output_size_per_partition,
|
||||||
scales_and_zp_size,
|
scales_and_zp_size,
|
||||||
dtype=torch.float8_e4m3fn,
|
dtype=params_dtype,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,17 +153,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
|||||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||||
)
|
)
|
||||||
|
|
||||||
# per-channel scales
|
|
||||||
weight_chan_scale = ChannelQuantScaleParameter(
|
|
||||||
data=torch.empty((output_size_per_partition, 1), dtype=torch.float32),
|
|
||||||
output_dim=0,
|
|
||||||
weight_loader=weight_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
layer.register_parameter("weight_packed", weight)
|
layer.register_parameter("weight_packed", weight)
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
layer.register_parameter("weight_shape", weight_shape)
|
layer.register_parameter("weight_shape", weight_shape)
|
||||||
layer.register_parameter("weight_chan_scale", weight_chan_scale)
|
|
||||||
|
|
||||||
self.kernel = kernel_type(
|
self.kernel = kernel_type(
|
||||||
mp_linear_kernel_config,
|
mp_linear_kernel_config,
|
||||||
|
|||||||
@ -6,7 +6,11 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
convert_bf16_scales_to_fp8,
|
||||||
|
convert_packed_uint4b8_to_signed_int4_inplace,
|
||||||
|
)
|
||||||
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
|
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
|||||||
"CUTLASS W4A8, only supported int4",
|
"CUTLASS W4A8, only supported int4",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(czhu): support -1 (column-wise)
|
|
||||||
if c.group_size != 128:
|
if c.group_size != 128:
|
||||||
return False, "Only group_size 128 is supported"
|
return False, "Only group_size 128 is supported"
|
||||||
|
|
||||||
@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
|||||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||||
# TODO(czhu): optimize speed/mem usage
|
|
||||||
def transform_w_q(x):
|
def transform_w_q(x):
|
||||||
assert isinstance(x, BasevLLMParameter)
|
assert isinstance(x, BasevLLMParameter)
|
||||||
|
convert_packed_uint4b8_to_signed_int4_inplace(x.data)
|
||||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||||
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
|
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
|
||||||
return x
|
return x
|
||||||
@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
|||||||
x.data = ops.cutlass_pack_scale_fp8(x.data)
|
x.data = ops.cutlass_pack_scale_fp8(x.data)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
w_s = getattr(layer, self.w_s_name)
|
||||||
|
fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
|
||||||
|
w_s.data = fp8_scales
|
||||||
|
|
||||||
|
# register per-channel scales
|
||||||
|
layer.register_parameter(
|
||||||
|
"weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
|
||||||
|
)
|
||||||
|
|
||||||
# Encode/reorder weights and pack scales
|
# Encode/reorder weights and pack scales
|
||||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||||
self._transform_param(layer, "weight_chan_scale", lambda x: x)
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""This file is used for /tests and /benchmarks"""
|
"""This file is used for /tests and /benchmarks"""
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import ClassVar, NamedTuple
|
from typing import ClassVar, NamedTuple
|
||||||
@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
|
|||||||
capability_tuple = current_platform.get_device_capability()
|
capability_tuple = current_platform.get_device_capability()
|
||||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||||
return cutlass_scaled_mm_supports_fp4(capability)
|
return cutlass_scaled_mm_supports_fp4(capability)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bf16_scales_to_fp8(
    quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales)
    expected by W4A8 GEMM kernels.

    The scales are flattened to 2-D (last dim kept), passed through
    ``quant_fp8``, and then rebalanced: the fp8 output is divided by 8 and
    re-cast to ``torch.float8_e4m3fn`` while the second output is multiplied
    by 8 in place.

    Args:
        quant_fp8: callable taking the flattened scales and returning a
            2-tuple (quantized scales, second-level scales).
        scales: contiguous CUDA tensor; its last dim is the group dimension.

    Returns:
        ``(fp8_scales, chan_scales)``; ``fp8_scales`` has the shape of
        ``scales``.
    """
    assert scales.is_contiguous(), (
        f"scale tensor must be contiguous, got {scales.stride()=}"
    )
    assert scales.is_cuda, "scales must be on gpu"

    full_shape = scales.shape
    leading_dims, num_k_groups = full_shape[:-1], full_shape[-1]

    # Collapse all leading dims so the quantizer sees one row per channel.
    rows = scales.view(-1, num_k_groups)
    quantized, per_row = quant_fp8(rows)

    # Shift a factor of 8 from the fp8 scales into the per-row scales.
    quantized = (quantized.float() / 8.0).to(torch.float8_e4m3fn)
    per_row *= 8.0

    # Restore the original leading dims.
    # NOTE(review): view() receives a torch.Size and -1 as two separate
    # positional arguments here (not `*leading_dims, -1`) - confirm the
    # resulting channel-scale shape is the intended one.
    return quantized.view(full_shape), per_row.view(leading_dims, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
    """
    Convert int4b8 (packed to int32) to signed int4, in place.

    Each int32 entry holds eight 4-bit nibbles. A nibble ``v`` in [0..15] is
    mapped to the 4-bit two's-complement encoding of ``v - 8``, i.e.
    ``(v - 8) & 0xF``. For every 4-bit value that is the same as flipping the
    nibble's top bit (``v ^ 0x8``), so all eight nibbles are converted with a
    single in-place XOR against 0x88888888 instead of eight
    extract/mask/merge passes.

    Args:
        t: CUDA tensor of dtype ``torch.int32`` holding packed int4b8 values.

    Returns:
        The same tensor ``t``, modified in place.

    Raises:
        AssertionError: if ``t`` is not on a CUDA device or is not int32.
    """
    assert t.is_cuda, "tensor must be on gpu"
    assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"

    # 0x88888888 does not fit in a signed int32, so spell the same bit
    # pattern as its two's-complement negative value (-2004318072).
    nibble_sign_bits = 0x88888888 - (1 << 32)
    t.bitwise_xor_(nibble_sign_bits)

    return t
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user