mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 04:35:47 +08:00
[Bug] Fix Cutlass Scaled MM Compilation Error (#24887)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
aae725af7c
commit
e757a629e7
@ -146,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
using ElementBlockScale = typename Gemm::ElementBlockScale;
|
||||||
|
|
||||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||||
|
|
||||||
@ -166,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
|
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
|
||||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
|
||||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||||
|
|
||||||
auto mainloop_args = [&](){
|
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||||
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
|
mainloop_args.layout_SFA = layout_SFA;
|
||||||
if (swap_ab) {
|
mainloop_args.layout_SFB = layout_SFB;
|
||||||
return typename GemmKernel::MainloopArguments{
|
if (swap_ab) {
|
||||||
b_ptr, b_stride, a_ptr, a_stride,
|
mainloop_args.ptr_A = b_ptr;
|
||||||
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
|
mainloop_args.dA = b_stride;
|
||||||
};
|
mainloop_args.ptr_B = a_ptr;
|
||||||
}
|
mainloop_args.dB = a_stride;
|
||||||
else {
|
mainloop_args.ptr_SFA = b_scales_ptr;
|
||||||
return typename GemmKernel::MainloopArguments{
|
mainloop_args.ptr_SFB = a_scales_ptr;
|
||||||
a_ptr, a_stride, b_ptr, b_stride,
|
} else {
|
||||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
mainloop_args.ptr_A = a_ptr;
|
||||||
};
|
mainloop_args.dA = a_stride;
|
||||||
}
|
mainloop_args.ptr_B = b_ptr;
|
||||||
}();
|
mainloop_args.dB = b_stride;
|
||||||
|
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||||
|
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||||
|
}
|
||||||
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
|
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
|
||||||
|
|
||||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
|||||||
@ -125,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
using ElementBlockScale = typename Gemm::ElementBlockScale;
|
||||||
|
|
||||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||||
|
|
||||||
@ -143,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
LayoutSFB layout_SFB =
|
LayoutSFB layout_SFB =
|
||||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
|
||||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||||
|
|
||||||
auto mainloop_args = [&](){
|
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||||
return typename GemmKernel::MainloopArguments{
|
mainloop_args.ptr_A = a_ptr;
|
||||||
a_ptr, a_stride, b_ptr, b_stride,
|
mainloop_args.dA = a_stride;
|
||||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
mainloop_args.ptr_B = b_ptr;
|
||||||
};
|
mainloop_args.dB = b_stride;
|
||||||
}();
|
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||||
|
mainloop_args.layout_SFA = layout_SFA;
|
||||||
|
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||||
|
mainloop_args.layout_SFB = layout_SFB;
|
||||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||||
|
|
||||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
|||||||
@ -115,6 +115,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
using ElementBlockScale = typename Gemm::ElementBlockScale;
|
||||||
|
|
||||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||||
|
|
||||||
@ -135,17 +136,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
LayoutSFB layout_SFB =
|
LayoutSFB layout_SFB =
|
||||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
|
||||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||||
|
|
||||||
auto mainloop_args = [&](){
|
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||||
return typename GemmKernel::MainloopArguments{
|
mainloop_args.ptr_A = a_ptr;
|
||||||
a_ptr, a_stride, b_ptr, b_stride,
|
mainloop_args.dA = a_stride;
|
||||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
mainloop_args.ptr_B = b_ptr;
|
||||||
};
|
mainloop_args.dB = b_stride;
|
||||||
}();
|
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||||
|
mainloop_args.layout_SFA = layout_SFA;
|
||||||
|
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||||
|
mainloop_args.layout_SFB = layout_SFB;
|
||||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||||
|
|
||||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user