[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel (#23280)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent d4fd2768ef
commit c3aea10dc8

Modified file:
@@ -4,7 +4,10 @@
 import torch
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    apply_w8a8_block_fp8_linear,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_BLOCK_FP8_SUPPORTED,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton as vllm_triton
@@ -29,7 +32,7 @@ DEEPSEEK_V3_SHAPES = [
 ]
 
 
-def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
+def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     """Build runner function for w8a8 block fp8 matmul."""
     factor_for_scale = 1e-2
 
@@ -37,37 +40,54 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
     # Create random FP8 tensors
-    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
 
-    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     # Create scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
 
-    As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale
     Bs = (
         torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
         * factor_for_scale
     )
 
+    # SM90 CUTLASS requires row-major format for scales
+    if use_cutlass and current_platform.is_device_capability(90):
+        Bs = Bs.T.contiguous()
+
     def run():
-        return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16)
+        if use_cutlass:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
+            )
+        else:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
+            )
 
     return run
 
 
+# Determine available providers
+available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
+plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
+
+if CUTLASS_BLOCK_FP8_SUPPORTED:
+    available_providers.append("w8a8-block-fp8-cutlass")
+
+
 @vllm_triton.testing.perf_report(
     vllm_triton.testing.Benchmark(
         x_names=["batch_size"],
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=["torch-bf16", "w8a8-block-fp8"],
-        line_names=["torch-bf16", "w8a8-block-fp8"],
+        line_vals=available_providers,
+        line_names=available_providers,
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs W8A8 Block FP8 GEMMs",
         args={},
@@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-    else:  # w8a8-block-fp8
-        run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device)
-        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
-            lambda: run_w8a8(), quantiles=quantiles
-        )
+    elif provider == "w8a8-block-fp8-triton":
+        run_w8a8_triton = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=False
+        )
+        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
+            lambda: run_w8a8_triton(), quantiles=quantiles
+        )
+    elif provider == "w8a8-block-fp8-cutlass":
+        run_w8a8_cutlass = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=True
+        )
+        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
+            lambda: run_w8a8_cutlass(), quantiles=quantiles
+        )
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
 
     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
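
For readers skimming the diff, this is what the benchmarked operation computes: a GEMM whose FP8 weight is dequantized with one scale per (block_n x block_k) tile of B (128x128 here); the SM90 CUTLASS path additionally wants that scale grid transposed to row-major, hence the Bs.T.contiguous() above. A minimal PyTorch reference sketch, assuming bf16 activations and skipping activation quantization (illustrative only, not vLLM's kernel):

    import torch

    def ref_block_fp8_matmul(A, B_fp8, Bs, block_size, out_dtype=torch.bfloat16):
        # A: (M, K) bf16 activations; B_fp8: (N, K) float8_e4m3fn weight
        # Bs: (n_tiles, k_tiles) per-block weight scales; block_size = (block_n, block_k)
        block_n, block_k = block_size
        # Broadcast each per-block scale over its tile, then crop to the weight shape
        scales = Bs.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
        B_deq = B_fp8.to(torch.float32) * scales[: B_fp8.shape[0], : B_fp8.shape[1]]
        return (A.to(torch.float32) @ B_deq.t()).to(out_dtype)

The fused Triton and CUTLASS kernels never materialize B_deq; they apply the per-tile scales inside the GEMM mainloop, which is what the deleted SM90 extension files below (now replaced by upstream CUTLASS) implemented.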

Deleted file:
@@ -1,123 +0,0 @@
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
|
|
||||||
// clang-format off
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
|
|
||||||
|
|
||||||
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
|
||||||
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace cutlass::gemm::collective {
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// GMMA_TMA_WS_SS (BlockScaled Builders)
|
|
||||||
template <
|
|
||||||
class ElementA,
|
|
||||||
class GmemLayoutATag,
|
|
||||||
int AlignmentA,
|
|
||||||
class ElementB,
|
|
||||||
class GmemLayoutBTag,
|
|
||||||
int AlignmentB,
|
|
||||||
class ElementAccumulator,
|
|
||||||
class TileShape_MNK,
|
|
||||||
class ClusterShape_MNK,
|
|
||||||
class StageCountType,
|
|
||||||
int ScaleGranularityM
|
|
||||||
>
|
|
||||||
struct CollectiveBuilder<
|
|
||||||
arch::Sm90,
|
|
||||||
arch::OpClassTensorOp,
|
|
||||||
ElementA,
|
|
||||||
GmemLayoutATag,
|
|
||||||
AlignmentA,
|
|
||||||
ElementB,
|
|
||||||
GmemLayoutBTag,
|
|
||||||
AlignmentB,
|
|
||||||
ElementAccumulator,
|
|
||||||
TileShape_MNK,
|
|
||||||
ClusterShape_MNK,
|
|
||||||
StageCountType,
|
|
||||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
|
|
||||||
cute::enable_if_t<
|
|
||||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
|
||||||
> {
|
|
||||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
|
||||||
|
|
||||||
static_assert(is_static<TileShape_MNK>::value);
|
|
||||||
static_assert(is_static<ClusterShape_MNK>::value);
|
|
||||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
|
||||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
|
||||||
#endif
|
|
||||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
|
||||||
"Should meet TMA alignment requirement\n");
|
|
||||||
|
|
||||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
|
||||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
|
||||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
|
||||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
|
||||||
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
|
|
||||||
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
|
|
||||||
|
|
||||||
// For fp32 types, map to tf32 MMA value type
|
|
||||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
|
||||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
|
||||||
|
|
||||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
|
||||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
|
||||||
|
|
||||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
|
||||||
KernelTmaWarpSpecializedCooperative,
|
|
||||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
|
||||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
|
|
||||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
|
||||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
|
||||||
|
|
||||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
|
||||||
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
|
||||||
|
|
||||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
|
||||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
|
||||||
|
|
||||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
|
||||||
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
|
||||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
|
||||||
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
|
||||||
|
|
||||||
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
|
||||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
|
||||||
|
|
||||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
|
||||||
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
|
|
||||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
|
|
||||||
|
|
||||||
using SmemCopyAtomA = void;
|
|
||||||
using SmemCopyAtomB = void;
|
|
||||||
|
|
||||||
using CollectiveOp = CollectiveMma<
|
|
||||||
DispatchPolicy,
|
|
||||||
TileShape_MNK,
|
|
||||||
ElementA,
|
|
||||||
TagToStrideA_t<GmemLayoutATag>,
|
|
||||||
ElementB,
|
|
||||||
TagToStrideB_t<GmemLayoutBTag>,
|
|
||||||
TiledMma,
|
|
||||||
GmemTiledCopyA,
|
|
||||||
SmemLayoutAtomA,
|
|
||||||
SmemCopyAtomA,
|
|
||||||
cute::identity,
|
|
||||||
GmemTiledCopyB,
|
|
||||||
SmemLayoutAtomB,
|
|
||||||
SmemCopyAtomB,
|
|
||||||
cute::identity
|
|
||||||
>;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
} // namespace cutlass::gemm::collective
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
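
One detail worth spelling out from the builder above: PipelineStages is just how many (A tile + B tile) operand stages fit in Hopper shared memory after the kernel carveout. A rough Python sketch of that arithmetic (the 232448-byte capacity is an assumption standing in for CUTLASS's detail::sm90_smem_capacity_bytes, and the real compute_stage_count_or_override also reserves space for pipeline bookkeeping):

    def sm90_pipeline_stages(tile_m, tile_n, tile_k, elem_bits=8,
                             smem_capacity_bytes=232448, carveout_bytes=0):
        # Shared-memory bytes one pipeline stage needs for the A and B operand tiles.
        bytes_a = tile_m * tile_k * elem_bits // 8
        bytes_b = tile_n * tile_k * elem_bits // 8
        return (smem_capacity_bytes - carveout_bytes) // (bytes_a + bytes_b)

    # A 128x128x128 FP8 tile needs 16 KiB for A and 16 KiB for B per stage,
    # so roughly 7 stages fit: sm90_pipeline_stages(128, 128, 128) == 7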

Deleted file:
@@ -1,183 +0,0 @@
// clang-format off
|
|
||||||
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
|
|
||||||
|
|
||||||
/***************************************************************************************************
|
|
||||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
||||||
* SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
*
|
|
||||||
* Redistribution and use in source and binary forms, with or without
|
|
||||||
* modification, are permitted provided that the following conditions are met:
|
|
||||||
*
|
|
||||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
* list of conditions and the following disclaimer.
|
|
||||||
*
|
|
||||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
* this list of conditions and the following disclaimer in the documentation
|
|
||||||
* and/or other materials provided with the distribution.
|
|
||||||
*
|
|
||||||
* 3. Neither the name of the copyright holder nor the names of its
|
|
||||||
* contributors may be used to endorse or promote products derived from
|
|
||||||
* this software without specific prior written permission.
|
|
||||||
*
|
|
||||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
*
|
|
||||||
**************************************************************************************************/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "cute/algorithm/clear.hpp"
|
|
||||||
#include "cute/tensor.hpp"
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
///////////////////////////////////FP8 Accumulation///////////////////////////
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
/// This class provides API to promote (add) or scale (multiply_add) the results
|
|
||||||
/// from the tensor core accumulators to the main accumulators when the number
|
|
||||||
/// of MMAs reaches the max number of MMA interval specified by user, after that
|
|
||||||
/// the tensor core accumulators are zeroed.
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace cutlass::gemm::collective {
|
|
||||||
|
|
||||||
template <
|
|
||||||
class EngineAccum,
|
|
||||||
class LayoutAccum>
|
|
||||||
struct GmmaFP8AccumulationWithScale {
|
|
||||||
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
|
|
||||||
using ElementAccumulator = typename EngineAccum::value_type;
|
|
||||||
|
|
||||||
static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
|
|
||||||
static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");
|
|
||||||
|
|
||||||
private:
|
|
||||||
TensorAccum& accum_;
|
|
||||||
TensorAccum accum_temp_;
|
|
||||||
|
|
||||||
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
|
|
||||||
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
|
|
||||||
uint32_t mma_count_; // current executed MMAs
|
|
||||||
uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
|
|
||||||
|
|
||||||
// promote or `add` the partial accumulators to main accumulator (FADD).
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void promote_core() {
|
|
||||||
warpgroup_wait<0>();
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < size(accum_); ++i) {
|
|
||||||
accum_(i) += accum_temp_(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
|
|
||||||
template <
|
|
||||||
class EngineScale,
|
|
||||||
class LayoutScale>
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
|
||||||
using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
|
|
||||||
|
|
||||||
static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
|
|
||||||
static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");
|
|
||||||
|
|
||||||
static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
|
|
||||||
|
|
||||||
warpgroup_wait<0>();
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < size(accum_); ++i) {
|
|
||||||
accum_(i) += accum_temp_(i) * scale(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
GmmaFP8AccumulationWithScale(
|
|
||||||
TensorAccum &accum,
|
|
||||||
uint32_t accum_promotion_interval,
|
|
||||||
uint32_t mma_count_per_mainloop_iteration)
|
|
||||||
: accum_(accum),
|
|
||||||
accum_promotion_interval_(accum_promotion_interval),
|
|
||||||
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
|
|
||||||
mma_count_(0),
|
|
||||||
reset_accum_flag_(0)
|
|
||||||
{
|
|
||||||
accum_temp_ = cute::make_fragment_like(accum);
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Methods (Common)
|
|
||||||
//
|
|
||||||
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
TensorAccum& operator()() {
|
|
||||||
return accum_temp_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// prepare the MMA accumulators when initialization or zeroing is required.
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
bool prepare_if_needed() {
|
|
||||||
return reset_accum_flag_;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Methods (for FADD version)
|
|
||||||
//
|
|
||||||
|
|
||||||
/// promote (add) the results from the MMA accumulators to main accumulator if needed.
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void promote_if_needed() {
|
|
||||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
|
||||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
|
||||||
if (reset_accum_flag_) {
|
|
||||||
promote_core();
|
|
||||||
mma_count_ = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void promote_residue_if_needed() {
|
|
||||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
|
||||||
promote_core();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Methods (for FFMA version)
|
|
||||||
//
|
|
||||||
|
|
||||||
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
|
|
||||||
template <
|
|
||||||
class EngineScale,
|
|
||||||
class LayoutScale>
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
|
||||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
|
||||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
|
||||||
if (reset_accum_flag_) {
|
|
||||||
scale_core(scale);
|
|
||||||
mma_count_ = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
|
|
||||||
template <
|
|
||||||
class EngineScale,
|
|
||||||
class LayoutScale>
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
|
||||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
|
||||||
scale_core(scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace cutlass::gemm::collective
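
The promote/scale machinery in GmmaFP8AccumulationWithScale above boils down to a two-level accumulator: partial FP8 MMA results collect in a temporary for a fixed number of MMAs, are then multiplied by the current block scale and folded into the main accumulator, and the temporary is zeroed. A scalar Python sketch of that schedule (interval and scale bookkeeping simplified for illustration):

    def blockwise_scaled_accumulate(partials, scales, promotion_interval=4):
        # partials: one value per MMA; scales: one value per promotion group.
        main_acc = 0.0
        temp_acc = 0.0
        for i, p in enumerate(partials, start=1):
            temp_acc += p                       # tensor-core accumulation
            if i % promotion_interval == 0:     # scale_if_needed()
                main_acc += temp_acc * scales[(i - 1) // promotion_interval]
                temp_acc = 0.0
        if temp_acc != 0.0:                     # scale_residue_if_needed()
            main_acc += temp_acc * scales[-1]
        return main_acc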

Deleted file:
@@ -1,729 +0,0 @@
// clang-format off
|
|
||||||
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
|
|
||||||
|
|
||||||
/***************************************************************************************************
|
|
||||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
||||||
* SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
*
|
|
||||||
* Redistribution and use in source and binary forms, with or without
|
|
||||||
* modification, are permitted provided that the following conditions are met:
|
|
||||||
*
|
|
||||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
* list of conditions and the following disclaimer.
|
|
||||||
*
|
|
||||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
* this list of conditions and the following disclaimer in the documentation
|
|
||||||
* and/or other materials provided with the distribution.
|
|
||||||
*
|
|
||||||
* 3. Neither the name of the copyright holder nor the names of its
|
|
||||||
* contributors may be used to endorse or promote products derived from
|
|
||||||
* this software without specific prior written permission.
|
|
||||||
*
|
|
||||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
*
|
|
||||||
**************************************************************************************************/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
|
||||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
|
||||||
#include "cutlass/trace.h"
|
|
||||||
#include "cutlass/numeric_types.h"
|
|
||||||
|
|
||||||
#include "cute/arch/cluster_sm90.hpp"
|
|
||||||
#include "cute/arch/copy_sm80.hpp"
|
|
||||||
#include "cute/arch/copy_sm90.hpp"
|
|
||||||
#include "cute/algorithm/functional.hpp"
|
|
||||||
#include "cute/atom/mma_atom.hpp"
|
|
||||||
#include "cute/algorithm/gemm.hpp"
|
|
||||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
|
||||||
|
|
||||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
|
||||||
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace cutlass::gemm::collective {
|
|
||||||
using namespace cute;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// WarpSpecialized Mainloop
|
|
||||||
template <
|
|
||||||
int Stages,
|
|
||||||
class ClusterShape,
|
|
||||||
class KernelSchedule,
|
|
||||||
int ScaleGranularityM_,
|
|
||||||
class TileShape_,
|
|
||||||
class ElementA_,
|
|
||||||
class StrideA_,
|
|
||||||
class ElementB_,
|
|
||||||
class StrideB_,
|
|
||||||
class TiledMma_,
|
|
||||||
class GmemTiledCopyA_,
|
|
||||||
class SmemLayoutAtomA_,
|
|
||||||
class SmemCopyAtomA_,
|
|
||||||
class TransformA_,
|
|
||||||
class GmemTiledCopyB_,
|
|
||||||
class SmemLayoutAtomB_,
|
|
||||||
class SmemCopyAtomB_,
|
|
||||||
class TransformB_>
|
|
||||||
struct CollectiveMma<
|
|
||||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
|
|
||||||
TileShape_,
|
|
||||||
ElementA_,
|
|
||||||
StrideA_,
|
|
||||||
ElementB_,
|
|
||||||
StrideB_,
|
|
||||||
TiledMma_,
|
|
||||||
GmemTiledCopyA_,
|
|
||||||
SmemLayoutAtomA_,
|
|
||||||
SmemCopyAtomA_,
|
|
||||||
TransformA_,
|
|
||||||
GmemTiledCopyB_,
|
|
||||||
SmemLayoutAtomB_,
|
|
||||||
SmemCopyAtomB_,
|
|
||||||
TransformB_>
|
|
||||||
{
|
|
||||||
//
|
|
||||||
// Type Aliases
|
|
||||||
//
|
|
||||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
|
|
||||||
using TileShape = TileShape_;
|
|
||||||
using ElementA = ElementA_;
|
|
||||||
using StrideA = StrideA_;
|
|
||||||
using ElementB = ElementB_;
|
|
||||||
using StrideB = StrideB_;
|
|
||||||
using TiledMma = TiledMma_;
|
|
||||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
|
||||||
using ElementBlockScale = ElementAccumulator;
|
|
||||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
|
||||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
|
||||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
|
||||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
|
||||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
|
||||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
|
||||||
using TransformA = TransformA_;
|
|
||||||
using TransformB = TransformB_;
|
|
||||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
|
||||||
|
|
||||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
|
||||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
|
||||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
|
||||||
using PipelineParams = typename MainloopPipeline::Params;
|
|
||||||
|
|
||||||
// Two threads per CTA are producers (1 for operand tile and 32 for scales)
|
|
||||||
static constexpr int NumProducerThreadEvents = 33;
|
|
||||||
|
|
||||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
|
||||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
|
||||||
|
|
||||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
|
||||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
|
||||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
|
||||||
|
|
||||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
|
||||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
|
||||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
|
||||||
|
|
||||||
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
|
||||||
|
|
||||||
// Tile along modes in a way that maximizes the TMA box size.
|
|
||||||
using SmemLayoutA = decltype(tile_to_shape(
|
|
||||||
SmemLayoutAtomA{},
|
|
||||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
|
||||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
|
||||||
using SmemLayoutB = decltype(tile_to_shape(
|
|
||||||
SmemLayoutAtomB{},
|
|
||||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
|
||||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
|
||||||
|
|
||||||
// Block scaling gmem-to-smem copy atom
|
|
||||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
|
||||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
|
||||||
|
|
||||||
// Block scaling smem layout
|
|
||||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
|
||||||
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
|
|
||||||
|
|
||||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
|
||||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
|
||||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
|
||||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
|
||||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
|
||||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
|
||||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
|
||||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
|
||||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
|
||||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
|
||||||
|
|
||||||
struct SharedStorage
|
|
||||||
{
|
|
||||||
struct TensorStorage : cute::aligned_struct<128> {
|
|
||||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
|
|
||||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
|
|
||||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
|
|
||||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
|
|
||||||
} tensors;
|
|
||||||
|
|
||||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
|
||||||
PipelineStorage pipeline;
|
|
||||||
};
|
|
||||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
|
||||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
|
||||||
|
|
||||||
// Host side kernel arguments
|
|
||||||
struct Arguments {
|
|
||||||
ElementA const* ptr_A;
|
|
||||||
StrideA dA;
|
|
||||||
ElementB const* ptr_B;
|
|
||||||
StrideB dB;
|
|
||||||
ElementBlockScale const* ptr_scale_A;
|
|
||||||
ElementBlockScale const* ptr_scale_B;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Device side kernel params
|
|
||||||
struct Params {
|
|
||||||
// Assumption: StrideA is congruent with Problem_MK
|
|
||||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
|
||||||
GmemTiledCopyA{},
|
|
||||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
|
||||||
SmemLayoutA{}(_,_,0),
|
|
||||||
TileShape{},
|
|
||||||
ClusterShape{}));
|
|
||||||
// Assumption: StrideB is congruent with Problem_NK
|
|
||||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
|
||||||
GmemTiledCopyB{},
|
|
||||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
|
||||||
SmemLayoutB{}(_,_,0),
|
|
||||||
TileShape{},
|
|
||||||
ClusterShape{}));
|
|
||||||
TMA_A tma_load_a;
|
|
||||||
TMA_B tma_load_b;
|
|
||||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
|
||||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
|
||||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
|
||||||
// Block scaling factors for A and B
|
|
||||||
ElementBlockScale const* ptr_scale_A;
|
|
||||||
ElementBlockScale const* ptr_scale_B;
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
|
||||||
// Methods
|
|
||||||
//
|
|
||||||
|
|
||||||
template <class ProblemShape>
|
|
||||||
static constexpr Params
|
|
||||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
|
||||||
(void) workspace;
|
|
||||||
|
|
||||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
|
||||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
|
||||||
auto [M,N,K,L] = problem_shape_MNKL;
|
|
||||||
|
|
||||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
|
||||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
|
||||||
|
|
||||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
|
||||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
|
||||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
|
||||||
GmemTiledCopyA{},
|
|
||||||
tensor_a,
|
|
||||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
|
||||||
TileShape{},
|
|
||||||
ClusterShape{});
|
|
||||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
|
||||||
GmemTiledCopyB{},
|
|
||||||
tensor_b,
|
|
||||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
|
||||||
TileShape{},
|
|
||||||
ClusterShape{});
|
|
||||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
|
||||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
|
||||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
|
||||||
|
|
||||||
return {
|
|
||||||
tma_load_a,
|
|
||||||
tma_load_b,
|
|
||||||
transaction_bytes,
|
|
||||||
transaction_bytes_mk,
|
|
||||||
transaction_bytes_nk,
|
|
||||||
args.ptr_scale_A,
|
|
||||||
args.ptr_scale_B
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class ProblemShape>
|
|
||||||
static bool
|
|
||||||
can_implement(
|
|
||||||
ProblemShape const& problem_shape,
|
|
||||||
[[maybe_unused]] Arguments const& args) {
|
|
||||||
constexpr int tma_alignment_bits = 128;
|
|
||||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
|
||||||
auto [M,N,K,L] = problem_shape_MNKL;
|
|
||||||
|
|
||||||
bool implementable = true;
|
|
||||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
|
||||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
|
||||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
|
||||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
|
||||||
|
|
||||||
if (!implementable) {
|
|
||||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
|
||||||
}
|
|
||||||
return implementable;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
|
||||||
static constexpr int K_PIPE_MMAS = 1;
|
|
||||||
static constexpr uint32_t TmaTransactionBytesMK =
|
|
||||||
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
|
||||||
static constexpr uint32_t TmaTransactionBytesNK =
|
|
||||||
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
|
||||||
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
|
||||||
|
|
||||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
|
||||||
{
|
|
||||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
|
||||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set up the data needed by this collective for load and mma.
|
|
||||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
|
||||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
|
||||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
|
||||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
|
||||||
template <class ProblemShape_MNKL>
|
|
||||||
CUTLASS_DEVICE auto
|
|
||||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
|
||||||
using X = Underscore;
|
|
||||||
// Separate out problem shape for convenience
|
|
||||||
auto [M,N,K,L] = problem_shape_MNKL;
|
|
||||||
|
|
||||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
|
||||||
// Represent the full tensors -- get these from TMA
|
|
||||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
|
||||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
|
||||||
|
|
||||||
// Make tiled views, defer the slice
|
|
||||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
|
||||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
|
||||||
|
|
||||||
constexpr auto scales_m = Int<ScaleMsPerTile>{};
|
|
||||||
auto tM = get<2>(gA_mkl.shape());
|
|
||||||
auto tN = get<2>(gB_nkl.shape());
|
|
||||||
auto tK = get<3>(gA_mkl.shape());
|
|
||||||
|
|
||||||
// Make the tiled views of scale tensors
|
|
||||||
auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
|
|
||||||
auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
|
|
||||||
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
|
|
||||||
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
|
|
||||||
|
|
||||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
|
||||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
|
||||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
|
|
||||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
|
|
||||||
|
|
||||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Perform a collective-scoped matrix multiply-accumulate
|
|
||||||
/// Producer Perspective
|
|
||||||
template <
|
|
||||||
class TensorA, class TensorB,
|
|
||||||
class TensorScaleA, class TensorScaleB,
|
|
||||||
class KTileIterator, class BlockCoord
|
|
||||||
>
|
|
||||||
CUTLASS_DEVICE void
|
|
||||||
load(
|
|
||||||
Params const& mainloop_params,
|
|
||||||
MainloopPipeline pipeline,
|
|
||||||
PipelineState smem_pipe_write,
|
|
||||||
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
|
|
||||||
BlockCoord const& blk_coord,
|
|
||||||
KTileIterator k_tile_iter, int k_tile_count,
|
|
||||||
int thread_idx,
|
|
||||||
uint32_t block_rank_in_cluster,
|
|
||||||
TensorStorage& shared_tensors) {
|
|
||||||
int lane_predicate = cute::elect_one_sync();
|
|
||||||
|
|
||||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
|
||||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
|
||||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
|
||||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
|
||||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
|
||||||
|
|
||||||
//
|
|
||||||
// Prepare the TMA loads for A and B
|
|
||||||
//
|
|
||||||
|
|
||||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
|
||||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
|
||||||
|
|
||||||
Tensor gA_mkl = get<0>(load_inputs);
|
|
||||||
Tensor gB_nkl = get<1>(load_inputs);
|
|
||||||
|
|
||||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
|
||||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
|
||||||
|
|
||||||
// Partition the inputs based on the current block coordinates.
|
|
||||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
|
||||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
|
||||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
|
||||||
|
|
||||||
|
|
||||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
|
||||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
|
||||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
|
||||||
auto scales_m = get<0>(mScaleA_mkl.shape());
|
|
||||||
|
|
||||||
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
|
||||||
|
|
||||||
Tensor gScaleA = local_tile(
|
|
||||||
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
|
||||||
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
|
||||||
Tensor cScaleA = local_tile(
|
|
||||||
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
|
||||||
make_coord(m_coord,_,l_coord));
|
|
||||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
|
||||||
|
|
||||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
|
||||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
|
||||||
Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
|
||||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
|
||||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
|
||||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
|
||||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
|
||||||
|
|
||||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
|
||||||
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
|
||||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
|
||||||
|
|
||||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
|
||||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
|
||||||
|
|
||||||
// Applies the mapping from block_tma_a
|
|
||||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
|
||||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
|
||||||
|
|
||||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
|
||||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
|
||||||
|
|
||||||
uint16_t mcast_mask_a = 0;
|
|
||||||
uint16_t mcast_mask_b = 0;
|
|
||||||
|
|
||||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
|
||||||
// Maps the tile -> block, value
|
|
||||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
|
||||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
|
||||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
|
||||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
|
||||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
|
||||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
|
||||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate predicate tensors for a_scales (since we can't guarantee that
|
|
||||||
// all scales are valid, since we could have a partial tiles along M)
|
|
||||||
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
|
||||||
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mainloop
|
|
||||||
CUTLASS_PRAGMA_NO_UNROLL
|
|
||||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
|
||||||
// LOCK smem_pipe_write for _writing_
|
|
||||||
pipeline.producer_acquire(smem_pipe_write);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Copy gmem to smem for *k_tile_iter
|
|
||||||
//
|
|
||||||
int write_stage = smem_pipe_write.index();
|
|
||||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
|
||||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
|
||||||
|
|
||||||
// Copy operands A and B from global memory to shared memory
|
|
||||||
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
|
||||||
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
|
||||||
|
|
||||||
// Copy scale tensors from global memory to shared memory
|
|
||||||
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
|
||||||
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
|
||||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
|
||||||
|
|
||||||
++k_tile_iter;
|
|
||||||
|
|
||||||
// Advance smem_pipe_write
|
|
||||||
++smem_pipe_write;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
|
||||||
CUTLASS_DEVICE void
|
|
||||||
load_tail(
|
|
||||||
MainloopPipeline pipeline,
|
|
||||||
PipelineState smem_pipe_write) {
|
|
||||||
int lane_predicate = cute::elect_one_sync();
|
|
||||||
|
|
||||||
// Issue the epilogue waits
|
|
||||||
if (lane_predicate) {
|
|
||||||
/* This helps avoid early exit of blocks in Cluster
|
|
||||||
* Waits for all stages to either be released (all
|
|
||||||
* Consumer UNLOCKs), or if the stage was never used
|
|
||||||
* then would just be acquired since the phase was
|
|
||||||
* still inverted from make_producer_start_state
|
|
||||||
*/
|
|
||||||
pipeline.producer_tail(smem_pipe_write);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Perform a collective-scoped matrix multiply-accumulate
|
|
||||||
/// Consumer Perspective
|
|
||||||
template <
|
|
||||||
class FrgTensorC
|
|
||||||
>
|
|
||||||
CUTLASS_DEVICE void
|
|
||||||
mma(MainloopPipeline pipeline,
|
|
||||||
PipelineState smem_pipe_read,
|
|
||||||
FrgTensorC& accum,
|
|
||||||
int k_tile_count,
|
|
||||||
int thread_idx,
|
|
||||||
TensorStorage& shared_tensors,
|
|
||||||
Params const& mainloop_params) {
|
|
||||||
|
|
||||||
|
|
||||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
|
||||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
|
||||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
|
||||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
|
||||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
|
||||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
|
||||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
|
||||||
|
|
||||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
|
||||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
|
||||||
|
|
||||||
// Block scaling
|
|
||||||
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
|
||||||
Layout<
|
|
||||||
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
|
||||||
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
|
||||||
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
|
||||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
|
||||||
|
|
||||||
//
|
|
||||||
// Define C accumulators and A/B partitioning
|
|
||||||
//
|
|
||||||
|
|
||||||
// Layout of warp group to thread mapping
|
|
||||||
|
|
||||||
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
|
||||||
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
|
||||||
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
|
||||||
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
|
||||||
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
|
||||||
|
|
||||||
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
|
||||||
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
|
||||||
Int<NumThreadsPerWarpGroup>{});
|
|
||||||
|
|
||||||
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
|
||||||
|
|
||||||
TiledMma tiled_mma;
|
|
||||||
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
|
||||||
|
|
||||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
|
||||||
|
|
||||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
|
||||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
|
||||||
|
|
||||||
// Allocate "fragments/descriptors"
|
|
||||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
|
||||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
|
||||||
|
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
|
||||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
|
||||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
|
||||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
|
||||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
|
||||||
|
|
||||||
//
|
|
||||||
// PIPELINED MAIN LOOP
|
|
||||||
//
|
|
||||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
|
||||||
"ERROR : Incorrect number of MMAs in flight");
|
|
||||||
|
|
||||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
|
||||||
PipelineState smem_pipe_release = smem_pipe_read;
|
|
||||||
|
|
||||||
// Per block scale values for operand A and B
|
|
||||||
|
|
||||||
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
|
||||||
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
|
||||||
|
|
||||||
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
|
||||||
ElementBlockScale scale_b;
|
|
||||||
|
|
||||||
// Prologue GMMAs
|
|
||||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
|
||||||
|
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
|
||||||
|
|
||||||
GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
|
|
||||||
warpgroup_fence_operand(accumulation());
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
|
||||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
|
||||||
{
|
|
||||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
|
||||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
|
||||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
|
||||||
|
|
||||||
if (accumulation.prepare_if_needed()) {
|
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
|
||||||
}
|
|
||||||
|
|
||||||
int read_stage = smem_pipe_read.index();
|
|
-        // Load per block scale values from shared memory to registers.
-        scale_b = sScaleB[read_stage];
-        CUTLASS_PRAGMA_UNROLL
-        for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
-          tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
-        }
-        if constexpr (ScaleMsPerTile == 1) {
-          static_assert(size(RegLayoutScaleAEssential{}) == 1);
-          tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
-        } else {
-          CUTLASS_PRAGMA_UNROLL
-          for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
-            tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
-          }
-        }
-
-        warpgroup_arrive();
-        // Unroll the K mode manually to set scale D to 1
-        CUTLASS_PRAGMA_UNROLL
-        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
-          // (V,M,K) x (V,N,K) => (V,M,N)
-          cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
-          tiled_mma.accumulate_ = GMMA::ScaleOut::One;
-        }
-        warpgroup_commit_batch();
-
-        // Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
-        accumulation.scale_if_needed(tCrScaleAViewAsC);
-
-        ++smem_pipe_read;
-      }
-
-      warpgroup_fence_operand(accumulation());
-      // Mainloop GMMAs
-      k_tile_count -= prologue_mma_count;
-
-      CUTLASS_PRAGMA_NO_UNROLL
-      for ( ; k_tile_count > 0; --k_tile_count)
-      {
-        // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
-        auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
-        pipeline.consumer_wait(smem_pipe_read, barrier_token);
-
-        //
-        // Compute on k_tile
-        //
-
-        int read_stage = smem_pipe_read.index();
-
-        // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
-        scale_b = sScaleB[read_stage];
-        CUTLASS_PRAGMA_UNROLL
-        for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
-          tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
-        }
-        if constexpr (ScaleMsPerTile == 1) {
-          static_assert(size(RegLayoutScaleAEssential{}) == 1);
-          tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
-        } else {
-          CUTLASS_PRAGMA_UNROLL
-          for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
-            tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
-          }
-        }
-
-        if (accumulation.prepare_if_needed()) {
-          tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
-        }
-
-        warpgroup_fence_operand(accumulation());
-        warpgroup_arrive();
-        // Unroll the K mode manually to set scale D to 1
-        CUTLASS_PRAGMA_UNROLL
-        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
-          // (V,M,K) x (V,N,K) => (V,M,N)
-          cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
-          tiled_mma.accumulate_ = GMMA::ScaleOut::One;
-        }
-        warpgroup_commit_batch();
-
-        /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
-        warpgroup_wait<K_PIPE_MMAS>();
-        warpgroup_fence_operand(accumulation());
-
-        // Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
-        accumulation.scale_if_needed(tCrScaleAViewAsC);
-
-        pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
-
-        // Advance smem_pipe_read and smem_pipe_release
-        ++smem_pipe_read;
-        ++smem_pipe_release;
-      }
-
-      accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
-
-      warpgroup_fence_operand(accumulation());
-  }
-
-  /// Perform a Consumer Epilogue to release all buffers
-  CUTLASS_DEVICE void
-  mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
-    // Prologue GMMAs
-    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
-    k_tile_count -= prologue_mma_count;
-
-    smem_pipe_release.advance(k_tile_count);
-
-    // Wait on all GMMAs to complete
-    warpgroup_wait<0>();
-
-    for (int count = 0; count < prologue_mma_count; ++count) {
-      pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
-      ++smem_pipe_release;
-    }
-  }
-};
-
-/////////////////////////////////////////////////////////////////////////////////////////////////
-
-} // namespace cutlass::gemm::collective
-
-/////////////////////////////////////////////////////////////////////////////////////////////////
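For readers skimming the removed mainloop above: the register-level work it did (loading per-block scales for A and B at each pipeline stage and calling `accumulation.scale_if_needed(...)`) amounts to a block-scaled FP8 matmul. A rough PyTorch sketch of that math, assuming the 1x128 activation groups and 128x128 weight blocks used elsewhere in this change (illustrative only, not the kernel's actual code):

import torch

def ref_block_scaled_gemm(A_q, B_q, As, Bs, block_k=128, block_n=128):
    # A_q: (M, K) fp8 activations, As: (M, ceil(K/block_k)) per-group scales
    # B_q: (N, K) fp8 weights,     Bs: (ceil(N/block_n), ceil(K/block_k)) scales
    M, K = A_q.shape
    N, _ = B_q.shape
    out = torch.zeros(M, N, dtype=torch.float32)
    for kb in range(0, K, block_k):
        # partial product over one K block, like one mainloop stage
        partial = A_q[:, kb:kb + block_k].float() @ B_q[:, kb:kb + block_k].float().T
        a_scale = As[:, kb // block_k].unsqueeze(1)                    # per-row scale
        b_scale = Bs[:, kb // block_k].repeat_interleave(block_n)[:N]  # per-column scale
        out += partial * a_scale * b_scale.unsqueeze(0)                # scale, then accumulate
    return out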
@@ -1,39 +0,0 @@
-#pragma once
-
-#include "cutlass/gemm/dispatch_policy.hpp"
-
-namespace cutlass::gemm {
-
-//////////////////////////////////////////////////////////////////////////////
-
-// FP8 related policies (including Blocked Scaled Accumulation)
-// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
-// `ScaleGranularityM` indicates that scaling granularity is
-// `size<0>(TileShape_MNK{})` along M.
-template <int ScaleGranularityM = 0>
-struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
-    : KernelTmaWarpSpecializedCooperative {};
-
-// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
-// specialized dynamic schedule For FP8 kernels with Block Scaling
-template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
-          class KernelSchedule = KernelTmaWarpSpecialized,
-          int ScaleGranularityM =
-              0  // `ScaleGranularityM` specifies scaling granularity along M,
-                 // while zero-value `ScaleGranularityM` indicates that scaling
-                 // granularity is `size<0>(TileShape_MNK{})` along M.
-          >
-struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
-    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
-                                         KernelSchedule> {
-  static_assert(
-      cute::is_same_v<
-          KernelSchedule,
-          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
-              ScaleGranularityM>>,
-      "KernelSchedule must be one of the warp specialized policies");
-};
-
-//////////////////////////////////////////////////////////////////////////////
-
-} // namespace cutlass::gemm
@@ -1,6 +1,6 @@
 #pragma once

-#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"

 namespace cutlass::gemm::collective {
 using namespace cute;
@@ -14,9 +14,6 @@
 #include "cutlass/epilogue/dispatch_policy.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"

-#include "cutlass_extensions/gemm/dispatch_policy.hpp"
-#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
-
 #include "cutlass_gemm_caller.cuh"

 namespace vllm {
@@ -14,9 +14,6 @@
 #include "cutlass/epilogue/dispatch_policy.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"

-#include "cutlass_extensions/gemm/dispatch_policy.hpp"
-#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
-
 #include "cutlass_gemm_caller.cuh"

 namespace vllm {
@@ -13,27 +13,18 @@
 #include "cutlass/epilogue/dispatch_policy.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"

-#include "cutlass_extensions/gemm/dispatch_policy.hpp"
-#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
-
 #include "cutlass_gemm_caller.cuh"

 namespace vllm {

 using namespace cute;

-template <typename SchedulerType, typename OutType, int GroupSizeM_,
-          int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
-          class ClusterShape = Shape<_1, _2, _1>>
+// clang-format off
+template <class OutType, int ScaleGranularityM,
+          int ScaleGranularityN, int ScaleGranularityK,
+          class MmaTileShape, class ClusterShape,
+          class EpilogueScheduler, class MainloopScheduler>
 struct cutlass_3x_gemm_fp8_blockwise {
-  using GroupSizeM = Int<GroupSizeM_>;
-  using GroupSizeN = Int<GroupSizeN_>;
-  using GroupSizeK = Int<GroupSizeK_>;
-  using TileSizeM = Int<TileSizeM_>;
-
-  static_assert(TileSizeM_ % GroupSizeM_ == 0,
-                "TileSizeM must be a multiple of GroupSizeM");
-
   using ElementAB = cutlass::float_e4m3_t;

   using ElementA = ElementAB;
@@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise {
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

   using ElementD = OutType;
-  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
+  using LayoutD = cutlass::layout::RowMajor;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

-  using ElementC = void;
-  using StrideC = StrideD;
+  using ElementC = void;  // TODO: support bias
+  using LayoutC = LayoutD;
   static constexpr int AlignmentC = AlignmentD;

   using ElementAccumulator = float;
-  using ElementBlockScale = float;
   using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
+
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
   using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
-  using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;

-  using KernelSchedule = cutlass::gemm::
-      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
-          GroupSizeM_>;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
-  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using ElementScalar = float;
+  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueScheduler,
+      DefaultOperation
+  >::CollectiveOp;

-  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
-      cutlass::epilogue::fusion::Sm90AccFetch>;
-  using CollectiveEpilogue =
-      typename cutlass::epilogue::collective::CollectiveBuilder<
-          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
-          ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
-          ElementD, StrideD, AlignmentD, EpilogueSchedule,
-          StoreEpilogueCompute>::CollectiveOp;
-  using CollectiveMainloop =
-      typename cutlass::gemm::collective::CollectiveBuilder<
-          ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
-          LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
-          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
-              sizeof(typename CollectiveEpilogue::SharedStorage))>,
-          KernelSchedule>::CollectiveOp;
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      MainloopScheduler
+  >::CollectiveOp;

   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
-      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      SchedulerType>>;
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

   struct GemmKernel : public KernelType {};
-
-  using StrideA = typename GemmKernel::StrideA;
-  using StrideB = typename GemmKernel::StrideB;
 };

 template <typename Gemm>
@@ -99,76 +105,54 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
   using GemmKernel = typename Gemm::GemmKernel;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutSFA = typename Gemm::LayoutSFA;
+  using LayoutSFB = typename Gemm::LayoutSFB;
+  using ScaleConfig = typename Gemm::ScaleConfig;

   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;

-  auto prob_shape = c3x::get_problem_shape(a, b);
-  int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
-          k = get<2>(prob_shape);
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);

-  int64_t lda = a.stride(0);
-  int64_t ldb = b.stride(1);
-  int64_t ldc = out.stride(0);
+  TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");

-  using StrideA = Stride<int64_t, Int<1>, int64_t>;
-  using StrideB = Stride<int64_t, Int<1>, int64_t>;
-  using StrideC = typename Gemm::StrideC;
+  StrideA a_stride;
+  StrideB b_stride;
+  StrideC c_stride;
+  a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));

-  StrideA a_stride{lda, Int<1>{}, 0};
-  StrideB b_stride{ldb, Int<1>{}, 0};
-  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
   auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
   auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
   auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

-  // Check is the t is contiguous and is 1D or 2D with one of the dimensions
-  // being 1 (i.e. a row or column vector)
-  auto is_contiguous_vector = [](const torch::Tensor& t) {
-    auto t_sizes = t.sizes();
-    return t.is_contiguous() &&
-           (t.dim() == 1 ||
-            (t.dim() == 2 &&
-             *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
-  };
-
-  // TODO(lucas): lets clean-up the kernel so that we pass in Strides so
-  // we don't have to deal with enforcing implicit layouts
-  TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
-  TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
-  TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
-              "a_scales must be M major");
-  TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
-  TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
-  TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
-              "b_scales must be K major");
-  typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
+  auto mainloop_args = [&](){
+    return typename GemmKernel::MainloopArguments{
+        a_ptr, a_stride, b_ptr, b_stride,
+        a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+    };
+  }();
+  auto prob_shape = cute::make_shape(m, n, k, 1);

   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
   typename GemmKernel::EpilogueArguments epilogue_args{
       {}, c_ptr, c_stride, c_ptr, c_stride};

-  typename GemmKernel::TileSchedulerArguments scheduler;
-
-  static constexpr bool UsesStreamKScheduler =
-      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
-                      cutlass::gemm::StreamKScheduler>;
-
-  if constexpr (UsesStreamKScheduler) {
-    using DecompositionMode = typename cutlass::gemm::kernel::detail::
-        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
-    using ReductionMode = typename cutlass::gemm::kernel::detail::
-        PersistentTileSchedulerSm90StreamKParams::ReductionMode;
-
-    scheduler.decomposition_mode = DecompositionMode::StreamK;
-    scheduler.reduction_mode = ReductionMode::Nondeterministic;
-  }
-
   c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
-                                       epilogue_args, scheduler);
+                                       epilogue_args);
 }

 template <typename OutType>
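For orientation, the tensor shapes the rewritten caller above works against can be summarized with a small illustrative helper; this is not part of the diff, it only restates relationships visible in the surrounding code (1x128 activation scale groups, 128x128 weight scale blocks, and the new `m % 4 == 0` check):

def expected_blockwise_shapes(m: int, n: int, k: int, block: int = 128):
    # Illustrative only: shape relationships for the SM90 blockwise FP8 GEMM.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    assert m % 4 == 0, "the caller enforces m divisible by 4"
    return {
        "a (fp8)": (m, k),
        "b (fp8, weight passed as B.T)": (k, n),
        "a_scales (fp32, 1x128 groups)": (m, ceil_div(k, block)),
        "b_scales (fp32, 128x128 blocks)": (ceil_div(k, block), ceil_div(n, block)),
    }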
@@ -177,18 +161,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  auto k = a.size(1);
-  auto n = b.size(1);
-
-  if (k > 3 * n) {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
-        out, a, b, a_scales, b_scales);
-  } else {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
-        out, a, b, a_scales, b_scales);
-  }
+  // TODO: better heuristics
+  cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+      OutType, 1, 128, 128, Shape<_128, _128, _128>,
+      Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative,
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>(
+      out, a, b, a_scales, b_scales);
 }

 } // namespace vllm
@@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
   TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
   int32_t version_num = get_sm_version_num();
-  if (version_num >= 100) {
+  if (version_num >= 90) {
     TORCH_CHECK(
         a.size(0) == a_scales.size(0) &&
             cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
@@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
             cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
             cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
         "b_scale_group_shape must be [128, 128].");
-  } else {
-    // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
-    // kernel, or introducing ceil_div to the load_init() of mainloop.
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
-
-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
-
-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
   }

   TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
@@ -11,8 +11,8 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
-    w8a8_block_fp8_matmul)
+    cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
+    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
@@ -98,6 +98,54 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
     assert rel_diff < 0.001


+@torch.inference_mode()
+def test_w8a8_block_fp8_cutlass_matmul():
+    # Test simple case where weight.shape % 128 != 0,
+    # like in DSV3 kv_a_proj_with_mqa
+    M = 32
+    N = 576
+    K = 7168
+    block_size = [128, 128]
+    out_dtype = torch.bfloat16
+    seed = 0
+
+    torch.manual_seed(seed)
+    factor_for_scale = 1e-2
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+
+    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+    block_n, block_k = block_size[0], block_size[1]
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+
+    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
+    # Hopper requires row-major format for scales
+    Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(
+        90) else Bs
+
+    A_fp8, As = per_token_group_quant_fp8(A_fp32,
+                                          block_size[1],
+                                          column_major_scales=False)
+    # CUTLASS uses column-major format for scales
+    A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
+        A_fp32, block_size[1], column_major_scales=True)
+
+    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
+                                       out_dtype)
+    out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass,
+                            block_size, out_dtype)
+
+    rel_diff = (torch.mean(
+        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
+                torch.mean(torch.abs(ref_out.to(torch.float32))))
+    assert rel_diff < 0.001
+
+
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
     itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
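One detail worth calling out in the new test is the `Bs.T.contiguous()` for Hopper: the layer stores weight scales as `(N/128, K/128)`, while the SM90 CUTLASS path wants them row-major with the K-tile dimension leading. A minimal sketch of what that transpose does to the memory layout, reusing the test's own N=576, K=7168 (illustrative, not part of the diff):

import torch

n_tiles, k_tiles = 5, 56              # ceil(576 / 128), 7168 / 128
Bs = torch.rand(n_tiles, k_tiles)     # layout the layer stores

# The Triton/reference path consumes Bs as-is; for the SM90 CUTLASS path the
# diff hands over a transposed, freshly contiguous copy instead.
Bs_cutlass = Bs.T.contiguous()        # shape (k_tiles, n_tiles)
assert Bs_cutlass.stride() == (n_tiles, 1)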
@@ -30,7 +30,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
     select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
+    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace,
+    should_use_deepgemm_for_fp8_linear)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
     prepare_moe_fp8_layer_for_marlin)
@@ -462,6 +463,15 @@ class Fp8LinearMethod(LinearMethodBase):
                 block_sz,
             )

+        # SM90 Block FP8 CUTLASS requires row-major weight scales
+        if (self.block_quant and current_platform.is_device_capability(90)
+                and self.cutlass_block_fp8_supported
+                and not should_use_deepgemm_for_fp8_linear(
+                    torch.bfloat16, layer.weight)):
+            layer.weight_scale_inv = Parameter(
+                layer.weight_scale_inv.data.T.contiguous(),
+                requires_grad=False)
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
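The same transformation in isolation: the weight scales are flipped once at weight-processing time, so the hot path never pays for a per-forward transpose. A hypothetical standalone helper (the name is illustrative, not a vLLM API):

import torch
from torch.nn import Parameter

def to_sm90_rowmajor_scales(weight_scale_inv: torch.Tensor) -> Parameter:
    # (N/128, K/128) block scales -> contiguous (K/128, N/128), mirroring what
    # process_weights_after_loading does for the SM90 CUTLASS block-FP8 path.
    return Parameter(weight_scale_inv.data.T.contiguous(), requires_grad=False)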
@@ -40,11 +40,14 @@ def cutlass_scaled_mm(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    return ops.cutlass_scaled_mm(A,
-                                 B.T,
-                                 out_dtype=output_dtype,
-                                 scale_a=As,
-                                 scale_b=Bs.T)
+    return ops.cutlass_scaled_mm(
+        A,
+        B.T,
+        out_dtype=output_dtype,
+        scale_a=As,
+        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
+        scale_b=Bs if block_size is not None
+        and current_platform.is_device_capability(90) else Bs.T)


 def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -152,35 +155,32 @@ def apply_w8a8_block_fp8_linear(
             output += bias
         return output.to(dtype=output_dtype).view(*output_shape)

-    if current_platform.is_cuda():
-        if current_platform.has_device_capability(100):
-
-            use_cutlass = cutlass_block_fp8_supported and (
-                cdiv(weight.shape[0], 128) == weight_scale.shape[0]
-                and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
-        else:
-            # TODO: update this after switching to public sm90 block scale gemm
-            # as it also supports weight.shape % 128 != 0
-            use_cutlass = cutlass_block_fp8_supported and (
-                weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
-    else:
-        use_cutlass = False
-
     w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
-        use_cutlass, use_aiter_and_is_supported)
-    if use_cutlass:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=use_cutlass)
+        cutlass_block_fp8_supported, use_aiter_and_is_supported)
+    if cutlass_block_fp8_supported:
+        num_pad = 0
+        if current_platform.is_device_capability(90):
+            # pad first dimension to be divisible by 4 due to
+            # cutlass blockwise gemm limitation for hopper
+            num_pad = 4 - (input_2d.shape[0] % 4)
+            if num_pad > 0:
+                input_2d = torch.nn.functional.pad(input_2d,
+                                                   (0, 0, 0, num_pad),
+                                                   "constant", 0)
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=True)
         output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
                                       block_size, input.dtype)
+        if num_pad > 0:
+            output = output[:-num_pad]
     else:
         if use_aiter_and_is_supported:
             q_input, x_scale = aiter_per1x128_quant(
                 input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
         else:
             q_input, x_scale = per_token_group_quant_fp8(
-                input_2d, block_size[1], column_major_scales=use_cutlass)
+                input_2d, block_size[1], column_major_scales=False)

         output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
                                       block_size, input.dtype)
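The Hopper branch above pads the token dimension so the kernel's `m % 4 == 0` requirement holds, then slices the padded rows back off the output. A minimal round-trip sketch of that pad/unpad idea (note it uses `(-M) % 4`, which is zero when M is already aligned, rather than the diff's `4 - (M % 4)`):

import torch

def pad_rows_to_multiple_of_4(x: torch.Tensor):
    # Pad the first dimension of an (M, K) tensor with zero rows.
    num_pad = (-x.shape[0]) % 4
    if num_pad:
        x = torch.nn.functional.pad(x, (0, 0, 0, num_pad), "constant", 0)
    return x, num_pad

x = torch.randn(30, 7168)
x_padded, num_pad = pad_rows_to_multiple_of_4(x)
out = x_padded @ torch.randn(7168, 576)   # stand-in for the fp8 GEMM call
if num_pad:
    out = out[:-num_pad]                  # drop rows added only for alignment
assert out.shape == (30, 576)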