Sm100 blockwise fp8 swap ab (#18564)

commit 5f2cd251d2 (parent 02658c2dfe)
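In brief: the SM100 blockwise-FP8 path previously required the activation row count m to be a multiple of 4 (the Python layer padded inputs to compensate). This change removes that requirement by adding a swap-A/B variant of the kernel: when m is small or not 4-aligned, the operands are exchanged and the transposed problem (n, m, k) is solved with transposed layouts, so the result still lands as row-major D of shape (m, n). A minimal standalone sketch of the identity this relies on (plain CPU loops for illustration; names invented, not part of the diff):

#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 3, K = 4;
  std::vector<float> A(M * K), B(K * N);  // both row-major for the demo
  for (int i = 0; i < M * K; ++i) A[i] = float(i + 1);
  for (int i = 0; i < K * N; ++i) B[i] = 0.5f * float(i);

  // Direct: D[m][n] = sum_k A[m][k] * B[k][n]
  std::vector<float> D(M * N, 0.f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        D[m * N + n] += A[m * K + k] * B[k * N + n];

  // Swapped: T = B^T * A^T, an (N x M) problem over the same K.
  // The kernel does this by exchanging pointers/strides and launching
  // with shape (n, m, k); no explicit transpose pass is performed.
  std::vector<float> T(N * M, 0.f);
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m)
      for (int k = 0; k < K; ++k)
        T[n * M + m] += B[k * N + n] * A[m * K + k];  // B^T[n][k] * A^T[k][m]

  // T is exactly D transposed: T[n][m] == D[m][n].
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      if (D[m * N + n] != T[n * M + m]) { std::puts("mismatch"); return 1; }
  std::puts("swap-A/B result matches the direct GEMM");
  return 0;
}

Because only layouts and the argument order change, neither a transpose pass nor the old padding copy is needed; the padding workaround is removed from the Python layer in the last hunk below.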
@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(
-      a.size(0) % 4 == 0,
-      "Input tensor must have a number of rows that is a multiple of 4. ",
-      "but got: ", a.size(0), " rows.");
   if (out.dtype() == torch::kBFloat16) {
     cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
         out, a, b, a_scales, b_scales);
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
 
@@ -22,49 +23,49 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
-          class ClusterShape, typename EpilogueScheduler,
-          typename MainloopScheduler>
+// clang-format off
+template <class OutType, int ScaleGranularityM,
+          int ScaleGranularityN, int ScaleGranularityK,
+          class MmaTileShape, class ClusterShape,
+          class EpilogueScheduler, class MainloopScheduler,
+          bool swap_ab_ = false>
 struct cutlass_3x_gemm_fp8_blockwise {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = cutlass::float_e4m3_t;
 
   using ElementA = ElementAB;
   using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
 
   using ElementB = ElementAB;
   using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
 
-  using ElementC = void;
   using ElementD = OutType;
   using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
 
+  using ElementC = void;  // TODO: support bias
   using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
   static constexpr int AlignmentC = AlignmentD;
 
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
 
-  // MMA and Cluster Tile Shapes
-  // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
-  // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
-  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
-  static constexpr int ScaleGranularityM =
-      size<0>(MmaTileShape{}) / ScaleMsPerTile;
-  static constexpr int ScaleGranularityN =
-      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
-  static constexpr int ScaleGranularityK =
-      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
-
-  // Shape of the threadblocks in a cluster
-  using ClusterShape_MNK = ClusterShape;
-
-  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
-      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  using ScaleConfig = conditional_t<swap_ab,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::K, cute::UMMA::Major::MN>,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
 
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
 
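The ScaleGranularity{M,N,K} template parameters above replace the old ScalesPerTile arithmetic: they state how many elements share one scale factor along each mode. In the non-swapped instantiations further down they are (1, 128, 128), i.e. per-row activation scales over 128-wide K groups and one scale per 128x128 weight block. A rough reference of what the blockwise-scaled GEMM computes under those assumptions (standalone sketch, not from the diff; tiny group sizes stand in for 128):

#include <cstdio>
#include <vector>

// Reference for blockwise-scaled GEMM:
//   D[m][n] = sum_k (A[m][k] * SA[m][k/GK]) * (B[k][n] * SB[k/GK][n/GN])
// GK and GN are the K/N scale granularities (128 in the real kernel).
void scaled_gemm_ref(int M, int N, int K, int GN, int GK,
                     const std::vector<float>& A, const std::vector<float>& SA,
                     const std::vector<float>& B, const std::vector<float>& SB,
                     std::vector<float>& D) {
  const int kGroups = (K + GK - 1) / GK;  // scale groups along K
  const int nGroups = (N + GN - 1) / GN;  // scale groups along N
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k)
        acc += A[m * K + k] * SA[m * kGroups + k / GK] *
               B[k * N + n] * SB[(k / GK) * nGroups + n / GN];
      D[m * N + n] = acc;
    }
}

int main() {
  // Tiny shapes; group size 2 stands in for the kernel's 128.
  const int M = 2, N = 4, K = 4, GN = 2, GK = 2;
  std::vector<float> A(M * K, 1.f), B(K * N, 1.f);
  std::vector<float> SA(M * (K / GK), 0.5f);        // one scale per 1xGK row group of A
  std::vector<float> SB((K / GK) * (N / GN), 2.f);  // one scale per GKxGN block of B
  std::vector<float> D(M * N);
  scaled_gemm_ref(M, N, K, GN, GK, A, SA, B, SB, D);
  std::printf("D[0][0] = %g\n", D[0]);  // 4 terms * (1*0.5) * (1*2) = 4
}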
@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
   using ElementScalar = float;
-  // clang-format off
   using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -84,17 +84,33 @@ struct cutlass_3x_gemm_fp8_blockwise {
       ElementAccumulator,
       ElementCompute,
       ElementC,
-      LayoutC,
+      conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
       AlignmentC,
       ElementD,
-      LayoutD,
+      conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
       AlignmentD,
       EpilogueScheduler,
       DefaultOperation
   >::CollectiveOp;
 
   using StageCountType = cutlass::gemm::collective::StageCountAuto;
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+  using CollectiveMainloop = conditional_t<swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementB,
+          cute::tuple<LayoutB_Transpose, LayoutSFA>,
+          AlignmentB,
+          ElementA,
+          cute::tuple<LayoutA_Transpose, LayoutSFB>,
+          AlignmentA,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainloopScheduler
+      >::CollectiveOp,
+      typename cutlass::gemm::collective::CollectiveBuilder<
       ArchTag,
       OperatorClass,
       ElementA,
@@ -106,11 +122,9 @@ struct cutlass_3x_gemm_fp8_blockwise {
       ElementAccumulator,
       MmaTileShape,
       ClusterShape,
-
       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
       MainloopScheduler
-  >::CollectiveOp;
-  // clang-format on
+  >::CollectiveOp>;
 
   using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   using ElementD = typename Gemm::ElementD;
 
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-  auto prob_shape = cute::make_shape(m, n, k, 1);
 
   StrideA a_stride;
   StrideB b_stride;
@@ -146,11 +160,13 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   b_stride =
       cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
   c_stride =
-      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+      cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
 
-  LayoutSFA layout_SFA =
+  LayoutSFA layout_SFA = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
-  LayoutSFB layout_SFB =
+  LayoutSFB layout_SFB = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
 
   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
@@ -158,9 +174,22 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
   auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptr, a_stride, b_ptr, b_stride,
-      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+  auto mainloop_args = [&](){
+    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+    if (swap_ab) {
+      return typename GemmKernel::MainloopArguments{
+        b_ptr, b_stride, a_ptr, a_stride,
+        b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
+      };
+    }
+    else {
+      return typename GemmKernel::MainloopArguments{
+        a_ptr, a_stride, b_ptr, b_stride,
+        a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+      };
+    }
+  }();
+  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
 
   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
   typename GemmKernel::EpilogueArguments epilogue_args{
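The mainloop arguments above are built with an immediately-invoked lambda so that one variable can be initialized from either operand order; both branches must return the same GemmKernel::MainloopArguments type for this to compile. A minimal sketch of the idiom with a hypothetical Args struct (illustrative only, not from the diff):

#include <cstdio>

// Hypothetical stand-in for GemmKernel::MainloopArguments.
struct Args { const float* ptr_a; const float* ptr_b; };

int main() {
  float a = 1.f, b = 2.f;
  bool swap_ab = true;  // the diff's flag is constexpr; a plain bool works too

  // Immediately-invoked lambda: pick the aggregate's field order once,
  // then bind the result to a single variable.
  auto args = [&]() {
    if (swap_ab) {
      return Args{&b, &a};  // operands exchanged, like b_ptr/a_ptr above
    } else {
      return Args{&a, &b};
    }
  }();

  std::printf("first operand: %g\n", *args.ptr_a);  // prints 2 when swapped
}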
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                                torch::Tensor const& b,
                                                torch::Tensor const& a_scales,
                                                torch::Tensor const& b_scales) {
-  auto m = a.size(0);
-  auto k = a.size(1);
-  auto n = b.size(1);
-  int sms;
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
 
-  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
-    return std::ceil(static_cast<float>(m) / tile1SM) *
-               std::ceil(static_cast<float>(n) / tile1SM) >=
-           sms;
-  };
-  bool use_2sm = should_use_2sm(m, n);
-  if (use_2sm) {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
-        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
-        out, a, b, a_scales, b_scales);
-  } else {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
-        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
-        out, a, b, a_scales, b_scales);
+  constexpr int TILE_K = 128;
+  // TODO: better heuristics
+  bool swap_ab = (m < 16) || (m % 4 != 0);
+  bool use_tma_epilogue = (m * n) % 4 == 0;
+  if (!swap_ab) {
+    constexpr int TILE_N = 128;
+    int tile_m = 256;
+    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
+      tile_m = 64;
+    }
+    else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
+      tile_m = 128;
+    }
+    if (tile_m == 64) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else if (tile_m == 128) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else { // tile_m == 256
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    }
+  } else {
+    // TODO: Test more tile N configs
+    constexpr int TILE_M = 128;
+    constexpr int TILE_N = 16;
+    // TMA epilogue isn't compatible with Swap A/B
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
+        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
+        out, a, b, a_scales, b_scales);
   }
 }
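The new dispatch picks the smallest M-tile whose grid still covers all SMs, and falls back to the swap-A/B kernel when m is small or not 4-aligned. Restated as standalone code (ceil_div mirrors the cuda_utils::ceil_div used via the include added above; the constants come from the diff, the surrounding scaffolding is illustrative):

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct Choice { bool swap_ab; int tile_m; bool use_tma_epilogue; };

Choice choose(int m, int n, int sms) {
  constexpr int TILE_N = 128;
  Choice c;
  c.swap_ab = (m < 16) || (m % 4 != 0);   // small or misaligned M: swap operands
  c.use_tma_epilogue = (m * n) % 4 == 0;  // alignment condition from the diff
  c.tile_m = 256;                         // default: 2-SM, 256-wide M tiles
  if (!c.swap_ab) {
    if (ceil_div(n, TILE_N) * ceil_div(m, 64) <= sms)
      c.tile_m = 64;                      // whole grid already fits at tile_m = 64
    else if (ceil_div(n, TILE_N) * ceil_div(m, 128) <= sms)
      c.tile_m = 128;
  }
  // When swap_ab is set, tile_m is unused: the swap path always takes the
  // fixed 128x16x128 NoSmem 1-SM configuration.
  return c;
}

int main() {
  // e.g. a small-batch GEMM: m = 8 rows on a 148-SM part
  Choice c = choose(/*m=*/8, /*n=*/4096, /*sms=*/148);
  std::printf("swap_ab=%d tile_m=%d tma=%d\n", c.swap_ab, c.tile_m, c.use_tma_epilogue);
}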
@@ -136,24 +136,10 @@ def apply_w8a8_block_fp8_linear(
         use_cutlass, use_aiter_and_is_supported)
 
     if use_cutlass:
-        rows, cols = input_2d.shape
-        # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
-        # optimal tensor core usage. Can be removed when targeting platforms
-        # without this constraint.
-        should_pad = current_platform.has_device_capability(
-            100) and rows % 4 != 0
-        if should_pad:
-            input_2d = torch.nn.functional.pad(input_2d,
-                                               (0, 0, 0, 4 - (rows % 4)),
-                                               value=0).contiguous()
-
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=use_cutlass)
 
         output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
                                       block_size, input.dtype)
-        if should_pad:
-            output = output[:rows, :]
-
     else:
         q_input, x_scale = per_token_group_quant_fp8(