mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-21 13:22:34 +08:00
[kernel] Support W4A8 on Hopper (#23198)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
parent
a75277285b
commit
e76e233540
@ -750,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
|
||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND W4A8_ARCHS)
|
||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
|
||||
@ -284,6 +284,25 @@ def machete_create_bench_fn(
|
||||
)
|
||||
|
||||
|
||||
def cutlass_w4a8_create_bench_fn(
|
||||
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
||||
) -> Callable:
|
||||
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||
w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||
# expects fp8 scales
|
||||
w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))
|
||||
|
||||
return lambda: ops.cutlass_w4a8_mm(
|
||||
a=bt.a,
|
||||
b_q=w_q,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=bt.group_size,
|
||||
b_channel_scales=bt.w_ch_s,
|
||||
a_token_scales=bt.w_tok_s,
|
||||
maybe_schedule=schedule,
|
||||
)
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
# bench
|
||||
@ -385,6 +404,20 @@ def bench(
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass w4a8
|
||||
if types.act_type == torch.float8_e4m3fn and group_size == 128:
|
||||
timers.append(
|
||||
bench_fns(
|
||||
label,
|
||||
sub_label,
|
||||
f"cutlass w4a8 ({name_type_string})",
|
||||
[
|
||||
cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
|
||||
for bt in benchmark_tensors
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if sweep_schedules:
|
||||
global _SWEEP_SCHEDULES_RESULTS
|
||||
|
||||
|
||||
@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
"CohereLabs/c4ai-command-a-03-2025": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 73728], 1),
|
||||
([36864, 12288], 0),
|
||||
],
|
||||
}
|
||||
|
||||
418
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
Normal file
418
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
Normal file
@ -0,0 +1,418 @@
|
||||
//
|
||||
// Based off of:
|
||||
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
|
||||
//
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm::cutlass_w4a8 {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Static configuration shared across all instantiations
|
||||
// -------------------------------------------------------------------------------------
|
||||
using MmaType = cutlass::float_e4m3_t; // A/scale element type
|
||||
using QuantType = cutlass::int4b_t; // B element type (packed int4)
|
||||
|
||||
static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
static int constexpr ScalePackSize = 8; // pack 8 scale elements together
|
||||
static int constexpr PackFactor = 8; // 8 4-bit packed into int32
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementA>::value; // Memory access granularity/alignment of A
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB =
|
||||
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
using LayoutB_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
// Define the CuTe layout for reordered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in
|
||||
// contiguous locations in global memory. It specifies the reordering within a
|
||||
// single warp's fragment
|
||||
using LayoutAtomQuant =
|
||||
decltype(cutlass::compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(
|
||||
LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{}));
|
||||
|
||||
// Group-wise scales
|
||||
using ElementScale = MmaType;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// Per-tok, per-chan scales
|
||||
using ElementSChannel = float;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC =
|
||||
cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC =
|
||||
cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementC>::value; // Memory access granularity/alignment of C
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
// based on the default
|
||||
// setting in the
|
||||
// Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel template — Tile/Cluster shapes
|
||||
// ----------------------------------------------------------------------------
|
||||
template <class TileShape_MN, class ClusterShape_MNK>
|
||||
struct W4A8GemmKernel {
|
||||
using TileShape =
|
||||
decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = ClusterShape_MNK;
|
||||
|
||||
// Epilogue per-tok, per-chan scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C
|
||||
// matrix. We can enable this if beta == 0 by changing ElementC to
|
||||
// void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type,
|
||||
AlignmentC, ElementD,
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule, // This is the only epi supporting the required
|
||||
// swap + transpose.
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
// The Scale information must get paired with the operand that will be scaled.
|
||||
// In this example, B is scaled so we make a tuple of B's information and the
|
||||
// scale information.
|
||||
using CollectiveMainloopShuffled =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>,
|
||||
LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloopShuffled, CollectiveEpilogue>;
|
||||
using GemmShuffled =
|
||||
cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelShuffled::StrideC;
|
||||
using StrideD = typename GemmKernelShuffled::StrideD;
|
||||
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
||||
|
||||
static torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type) {
|
||||
// TODO: param validation
|
||||
int m = A.size(0);
|
||||
int k = A.size(1);
|
||||
int n = B.size(1);
|
||||
|
||||
// Allocate output
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
auto device = A.device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
torch::Tensor D =
|
||||
torch::empty({m, n}, torch::TensorOptions()
|
||||
.dtype(equivalent_scalar_type_v<ElementD>)
|
||||
.device(device));
|
||||
// prepare arg pointers
|
||||
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto D_ptr = static_cast<ElementD*>(D.data_ptr());
|
||||
// can we avoid harcode the 8 here
|
||||
auto S_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
|
||||
group_scales.const_data_ptr());
|
||||
|
||||
// runtime layout for B
|
||||
auto shape_B = cute::make_shape(n, k, 1);
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
// strides
|
||||
int const scale_k = cutlass::ceil_div(k, group_size);
|
||||
StrideA stride_A =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
// Reverse stride here due to swap and transpose
|
||||
StrideD stride_D =
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
|
||||
StrideS stride_S = cutlass::make_cute_packed_stride(
|
||||
StrideS{}, cute::make_shape(n, scale_k, 1));
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an
|
||||
// instance of Gemm auto arguments =
|
||||
// args_from_options<GemmShuffled>(options);
|
||||
/// Populates a Gemm::Arguments structure from the given arguments
|
||||
/// Swap the A and B tensors, as well as problem shapes here.
|
||||
using Args = typename GemmShuffled::Arguments;
|
||||
using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
|
||||
|
||||
MainloopArguments mainloop_arguments{
|
||||
B_ptr, layout_B_reordered, A_ptr, stride_A,
|
||||
S_ptr, stride_S, group_size};
|
||||
|
||||
EpilogueArguments epilogue_arguments{
|
||||
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
|
||||
nullptr,
|
||||
{}, // no C
|
||||
D_ptr,
|
||||
stride_D};
|
||||
|
||||
Args arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{n, m, k, 1}, // shape
|
||||
mainloop_arguments,
|
||||
epilogue_arguments};
|
||||
|
||||
// Workspace
|
||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||
torch::Tensor workspace =
|
||||
torch::empty(workspace_size,
|
||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
||||
|
||||
// Run GEMM
|
||||
GemmShuffled gemm;
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
|
||||
CUTLASS_CHECK(gemm.run(stream));
|
||||
|
||||
return D;
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel instantiations and dispatch logic
|
||||
// ----------------------------------------------------------------------------
|
||||
using Kernel_256x128_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x64_1x1x1 = W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x32_1x1x1 = W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x16_1x1x1 = W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x256_2x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>;
|
||||
using Kernel_128x256_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x128_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
|
||||
|
||||
torch::Tensor mm_dispatch(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
const std::string& schedule) {
|
||||
if (schedule == "256x128_1x1x1") {
|
||||
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x64_1x1x1") {
|
||||
return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x32_1x1x1") {
|
||||
return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x16_1x1x1") {
|
||||
return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x256_2x1x1") {
|
||||
return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x256_1x1x1") {
|
||||
return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x128_1x1x1") {
|
||||
return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x64_1x1x1") {
|
||||
return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x32_1x1x1") {
|
||||
return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x16_1x1x1") {
|
||||
return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
||||
return {};
|
||||
}
|
||||
|
||||
torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
// requested a specific schedule
|
||||
if (maybe_schedule) {
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
token_scales, maybe_out_type, *maybe_schedule);
|
||||
}
|
||||
std::string schedule;
|
||||
int M = A.size(0);
|
||||
int K = A.size(1);
|
||||
int N = B.size(1);
|
||||
// heuristic
|
||||
if (M <= 16) {
|
||||
schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1";
|
||||
} else if (M <= 32) {
|
||||
schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1";
|
||||
} else if (M <= 64) {
|
||||
if (K == 16384 && N == 18432)
|
||||
schedule = "256x64_1x1x1";
|
||||
else if (N <= 8192 && K <= 8192)
|
||||
schedule = "128x32_1x1x1";
|
||||
else
|
||||
schedule = "128x64_1x1x1";
|
||||
} else if (M <= 128) {
|
||||
if (K == 16384 && N == 18432)
|
||||
schedule = "256x128_1x1x1";
|
||||
else if (N <= 8192)
|
||||
schedule = "128x64_1x1x1";
|
||||
else
|
||||
schedule = "128x128_1x1x1";
|
||||
} else if (M <= 256) {
|
||||
if (N <= 4096)
|
||||
schedule = "128x64_1x1x1";
|
||||
else if (N <= 8192)
|
||||
schedule = "128x128_1x1x1";
|
||||
else
|
||||
schedule = "128x256_1x1x1";
|
||||
} else if (M <= 512 && N <= 4096) {
|
||||
schedule = "128x128_1x1x1";
|
||||
} else if (M <= 1024) {
|
||||
schedule = "128x256_1x1x1";
|
||||
} else {
|
||||
schedule = "128x256_2x1x1";
|
||||
}
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
token_scales, maybe_out_type, schedule);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pre-processing utils
|
||||
// ----------------------------------------------------------------------------
|
||||
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
TORCH_CHECK(scales.is_cuda());
|
||||
|
||||
auto packed_scales = torch::empty(
|
||||
{scales.numel() * ScalePackSize},
|
||||
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
|
||||
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
|
||||
auto packed_scales_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
|
||||
packed_scales.data_ptr());
|
||||
|
||||
cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel());
|
||||
|
||||
return packed_scales;
|
||||
}
|
||||
|
||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(B.dim() == 2);
|
||||
|
||||
torch::Tensor B_packed = torch::empty_like(B);
|
||||
|
||||
int k = B.size(0) * PackFactor; // logical k
|
||||
int n = B.size(1);
|
||||
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
||||
auto shape_B = cute::make_shape(n, k, 1);
|
||||
auto layout_B = make_layout(shape_B, LayoutRight{}); // row major
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
|
||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||
|
||||
return B_packed;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_mm", &mm);
|
||||
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
|
||||
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8
|
||||
@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
||||
"SymInt size_n, int num_bits) -> Tensor");
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_mm("
|
||||
" Tensor A,"
|
||||
" Tensor B,"
|
||||
" Tensor group_scales,"
|
||||
" int group_size,"
|
||||
" Tensor channel_scales,"
|
||||
" Tensor token_scales,"
|
||||
" ScalarType? out_type,"
|
||||
" str? maybe_schedule"
|
||||
") -> Tensor",
|
||||
{stride_tag});
|
||||
// pack scales
|
||||
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
||||
// encode and reorder weight matrix
|
||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
// Dequantization for GGML.
|
||||
|
||||
259
tests/kernels/quantization/test_cutlass_w4a8.py
Normal file
259
tests/kernels/quantization/test_cutlass_w4a8.py
Normal file
@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the CUTLASS W4A8 kernel.
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass_w4a8.py`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows, quantize_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
|
||||
MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672),
|
||||
(13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096),
|
||||
(64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096),
|
||||
(1024, 4096, 8192), (1024, 8192, 4096)]
|
||||
|
||||
# TODO(czhu): get supported schedules from fn
|
||||
SCHEDULES = [
|
||||
'128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1',
|
||||
'128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1',
|
||||
'128x256_1x1x1', '128x256_2x1x1'
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: Optional[torch.dtype]
|
||||
group_scale_type: Optional[torch.dtype]
|
||||
channel_scale_type: Optional[torch.dtype]
|
||||
token_scale_type: Optional[torch.dtype]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tensors:
|
||||
w_ref: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: torch.Tensor
|
||||
w_ch_s: torch.Tensor
|
||||
w_tok_s: torch.Tensor
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
|
||||
Optional[torch.dtype], bool]
|
||||
TEST_TYPES = [
|
||||
*(
|
||||
TypeConfig(act_type=torch.float8_e4m3fn,
|
||||
weight_type=w_type,
|
||||
output_type=o_type,
|
||||
group_scale_type=torch.float8_e4m3fn,
|
||||
channel_scale_type=torch.float32,
|
||||
token_scale_type=torch.float32)
|
||||
for w_type in [scalar_types.int4]
|
||||
# TODO(czhu): fp16 out type
|
||||
for o_type in [torch.bfloat16]),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
# For testing quantized linear kernels
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return tensor.clamp(min=finfo.min,
|
||||
max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def cutlass_quantize_and_pack(atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(w,
|
||||
wtype,
|
||||
group_size=group_size,
|
||||
zero_points=zero_points)
|
||||
|
||||
# since scales are cast to fp8, we need to compute w_ref this way
|
||||
w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to(
|
||||
torch.float32).repeat_interleave(group_size, dim=0)).to(atype)
|
||||
|
||||
# bit mask prevents sign extending int4 when packing
|
||||
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
|
||||
w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||
w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype))
|
||||
|
||||
return w_ref, w_q_packed, w_s_packed, w_zp
|
||||
|
||||
|
||||
def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||
group_size: Optional[int]) -> Tensors:
|
||||
m, n, k = shape
|
||||
|
||||
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
|
||||
group_size)
|
||||
|
||||
a = to_fp8(torch.randn((m, k), device="cuda"))
|
||||
w = to_fp8(torch.randn((k, n), device="cuda"))
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
|
||||
False)
|
||||
|
||||
a_ref = a.to(torch.float32)
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
# for the practical use case we need per-tok scales for fp8 activations
|
||||
w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type)
|
||||
# weights are already per-group quantized, use placeholder here
|
||||
w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type)
|
||||
|
||||
return Tensors(w_ref=w_ref,
|
||||
a_ref=a_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
w_g_s=w_s,
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s)
|
||||
|
||||
|
||||
def mm_test_helper(types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: Optional[int] = None,
|
||||
schedule: Optional[str] = None):
|
||||
# CUTLASS upstream uses fp8 with fastaccum as reference
|
||||
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
|
||||
output_ref = torch._scaled_mm(
|
||||
tensors.a_ref.to(types.act_type),
|
||||
tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major
|
||||
tensors.w_tok_s.unsqueeze(1),
|
||||
tensors.w_ch_s.unsqueeze(0),
|
||||
out_dtype=types.output_type,
|
||||
use_fast_accum=True)
|
||||
|
||||
output = ops.cutlass_w4a8_mm(
|
||||
a=tensors.a,
|
||||
b_q=tensors.w_q,
|
||||
b_group_scales=tensors.w_g_s,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=tensors.w_ch_s,
|
||||
a_token_scales=tensors.w_tok_s,
|
||||
)
|
||||
|
||||
print(output)
|
||||
print(output_ref)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
output_ref.to(output.dtype),
|
||||
rtol=1e-3,
|
||||
atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="CUTLASS W4A8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
@pytest.mark.parametrize("schedule", SCHEDULES)
|
||||
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
|
||||
group_sizes = [128]
|
||||
for group_size in group_sizes:
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
mm_test_helper(types, tensors, group_size, schedule)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class W4A8Layer(torch.nn.Module):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_w4a8_mm(a=a, **self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="CUTLASS W4A8 is not supported on this GPU type.")
|
||||
def test_w4a8_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = to_fp8(torch.randn((m, k), device="cuda"))
|
||||
b = to_fp8(torch.randn((k, n), device="cuda"))
|
||||
|
||||
wtype = scalar_types.int4
|
||||
stype = torch.float8_e4m3fn
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
|
||||
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points)
|
||||
|
||||
w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32)
|
||||
w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32)
|
||||
|
||||
# Construct a trivial model with a single layer that calls the kernel
|
||||
model = W4A8Layer(
|
||||
b_q=w_q_packed,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=w_ch_s,
|
||||
a_token_scales=w_tok_s,
|
||||
)
|
||||
|
||||
output_ref = torch._scaled_mm(
|
||||
a,
|
||||
w_ref.to(a.dtype).t().contiguous().t(), # col major
|
||||
w_tok_s.unsqueeze(1),
|
||||
w_ch_s.unsqueeze(0),
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
|
||||
output.zero_()
|
||||
g.replay()
|
||||
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)
|
||||
@ -474,6 +474,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
return torch.empty_like(b_q_weight,
|
||||
memory_format=torch.contiguous_format)
|
||||
|
||||
@register_fake("_C::cutlass_w4a8_mm")
|
||||
def cutlass_w4a8_mm_fake(
|
||||
a: torch.Tensor,
|
||||
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
|
||||
b_q: torch.Tensor,
|
||||
b_group_scales: torch.Tensor,
|
||||
b_group_size: int,
|
||||
b_channel_scales: torch.Tensor,
|
||||
a_token_scales: torch.Tensor,
|
||||
out_type: Optional[torch.dtype] = None,
|
||||
maybe_schedule: Optional[str] = None) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
n = b_q.size(1)
|
||||
out_dtype = out_type if out_type is not None else torch.bfloat16
|
||||
return torch.empty((m, n), device=a.device, dtype=out_dtype)
|
||||
|
||||
@register_fake("_C::cutlass_pack_scale_fp8")
|
||||
def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(scales, memory_format=torch.contiguous_format)
|
||||
|
||||
@register_fake("_C::cutlass_encode_and_reorder_int4b")
|
||||
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
|
||||
|
||||
@ -1032,6 +1056,30 @@ def machete_prepack_B(
|
||||
group_scales_type)
|
||||
|
||||
|
||||
# CUTLASS W4A8
|
||||
def cutlass_w4a8_mm(
|
||||
a: torch.Tensor,
|
||||
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
|
||||
b_q: torch.Tensor,
|
||||
b_group_scales: torch.Tensor,
|
||||
b_group_size: int,
|
||||
b_channel_scales: torch.Tensor,
|
||||
a_token_scales: torch.Tensor,
|
||||
out_type: Optional[torch.dtype] = None,
|
||||
maybe_schedule: Optional[str] = None) -> torch.Tensor:
|
||||
return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size,
|
||||
b_channel_scales, a_token_scales,
|
||||
out_type, maybe_schedule)
|
||||
|
||||
|
||||
def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.cutlass_pack_scale_fp8(scales)
|
||||
|
||||
|
||||
def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "permute_cols"):
|
||||
|
||||
@register_fake("_C::permute_cols")
|
||||
|
||||
@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int,
|
||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
@ -200,8 +200,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
format
|
||||
) if format is not None else is_activation_quantization_format(
|
||||
quant_format)
|
||||
if act_quant_format:
|
||||
input_activations = quant_config.get("input_activations")
|
||||
# TODO(czhu): w4a8fp8 is in packed-quantized format
|
||||
# but needs input activation quantization
|
||||
input_activations = quant_config.get("input_activations")
|
||||
if act_quant_format or input_activations:
|
||||
# The only case where we have activation quant supported
|
||||
# but no input_activations provided in the config
|
||||
# should be w8a16fp8 w8a16fp8 can also run for cases where
|
||||
@ -352,6 +354,28 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
input_quant.strategy == QuantizationStrategy.TENSOR)
|
||||
return is_symmetric_activation and is_per_tensor_activation
|
||||
|
||||
def _is_fp8_w4a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
if not weight_quant or not input_quant:
|
||||
return False
|
||||
is_weight_4_bits = weight_quant.num_bits == 4
|
||||
is_activation_8_bits = input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.GROUP.value)
|
||||
is_token = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN.value)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||
# Only per-group symmetric weight (4bit)
|
||||
# + per-tok symmetric activation (8bit) quantization supported.
|
||||
return (is_weight_4_bits and is_activation_8_bits and is_token
|
||||
and is_symmetric and is_dynamic)
|
||||
|
||||
def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
return (self._check_scheme_supported(90, error=False, match_exact=True)
|
||||
and self._is_fp8_w4a8(weight_quant, input_quant))
|
||||
|
||||
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
return (self._check_scheme_supported(90, error=False, match_exact=True)
|
||||
@ -405,6 +429,13 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
|
||||
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (self.quant_format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
@ -21,5 +22,6 @@ __all__ = [
|
||||
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
|
||||
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int",
|
||||
"CompressedTensorsW4A8Fp8"
|
||||
]
|
||||
|
||||
@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A8Fp8"]
|
||||
W4A8_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.int4,
|
||||
}
|
||||
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None):
|
||||
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
|
||||
if self.group_size != 128 or self.strategy != "group":
|
||||
raise ValueError("W4A8 kernels require group quantization " \
|
||||
"with group size 128")
|
||||
|
||||
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
|
||||
|
||||
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# hopper
|
||||
return 90
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_type,
|
||||
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW4A8Fp8",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = (input_size != input_size_per_partition)
|
||||
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||
self.has_g_idx, self.group_size, row_parallel)
|
||||
|
||||
scales_and_zp_size = input_size // group_size
|
||||
|
||||
if partition_scales:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
weight = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
packed_factor=self.pack_factor,
|
||||
packed_dim=1,
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition //
|
||||
self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
))
|
||||
|
||||
# TODO(czhu): allocate the packed fp8 scales memory here?
|
||||
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
|
||||
weight_scale_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
|
||||
BitBLASLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
|
||||
ConchLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
|
||||
CutlassW4A8LinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
|
||||
Dynamic4bitLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||
@ -24,6 +26,7 @@ from vllm.platforms import current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
CutlassW4A8LinearKernel,
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
|
||||
@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# dynamic per-tok fp8 activation quantization
|
||||
self.quant_fp8 = QuantFP8(static=False,
|
||||
group_shape=GroupShape.PER_TOKEN)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CUTLASS only supported on CUDA"
|
||||
|
||||
if not current_platform.is_device_capability(90):
|
||||
return False, "CUTLASS W4A8 requires compute capability of 90 "\
|
||||
"(Hopper)"
|
||||
|
||||
if c.act_type != torch.float8_e4m3fn:
|
||||
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"
|
||||
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering not supported by CUTLASS W4A8"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points not supported by CUTLASS W4A8"
|
||||
|
||||
if c.weight_type != scalar_types.int4:
|
||||
return False, f"Quant type ({c.weight_type}) not supported by "\
|
||||
"CUTLASS W4A8, only supported int4"
|
||||
|
||||
# TODO(czhu): support -1 (column-wise)
|
||||
if c.group_size != 128:
|
||||
return False, "Only group_size 128 is supported"
|
||||
|
||||
in_features, out_features = c.partition_weight_shape
|
||||
if in_features % 128 or out_features % 128:
|
||||
return False, "K and N must be divisible by 128, got "\
|
||||
f"{c.partition_weight_shape}"
|
||||
return True, None
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
# TODO(czhu): optimize speed/mem usage
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = ops.cutlass_encode_and_reorder_int4b(
|
||||
x.data.t().contiguous().t())
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous().to(torch.float8_e4m3fn)
|
||||
x.data = ops.cutlass_pack_scale_fp8(x.data)
|
||||
return x
|
||||
|
||||
# Encode/reorder weights and pack scales
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
# TODO(czhu): support loading channel scales
|
||||
self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
assert bias is None, "bias not supported by CUTLASS W4A8"
|
||||
c = self.config
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
x_2d, act_scales = self.quant_fp8(x_2d)
|
||||
output = ops.cutlass_w4a8_mm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size,
|
||||
a_token_scales=act_scales,
|
||||
b_channel_scales=self.w_ch_s)
|
||||
|
||||
return output.reshape(out_shape)
|
||||
Loading…
x
Reference in New Issue
Block a user