mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 03:37:44 +08:00
[Bugfix] Machete garbage results for some models (large K dim) (#9212)
This commit is contained in:
parent
ce00231a8b
commit
a64e7b9407
@@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
|
|||||||
tma_load_b = make_tma_copy_B(
|
tma_load_b = make_tma_copy_B(
|
||||||
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
|
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
|
||||||
|
|
||||||
|
int32_t scale_k =
|
||||||
|
(ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
|
||||||
|
int32_t group_size = (ModeHasScales) ? args.group_size : 0;
|
||||||
|
|
||||||
if constexpr (ModeHasScales) {
|
if constexpr (ModeHasScales) {
|
||||||
tma_load_scale = make_tma_copy_scale(make_logical_tensor(
|
tma_load_scale = make_tma_copy_scale(
|
||||||
args.ptr_S, make_shape(M, args.group_size, L), args.dS));
|
make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (KernelConversionMode ==
|
if constexpr (KernelConversionMode ==
|
||||||
ConversionMode::ConvertAndScaleWithZero) {
|
ConversionMode::ConvertAndScaleWithZero) {
|
||||||
tma_load_zero = make_tma_copy_zero(make_logical_tensor(
|
tma_load_zero = make_tma_copy_zero(
|
||||||
args.ptr_Z, make_shape(M, args.group_size, L), args.dS));
|
make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
|
||||||
return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0};
|
KernelConversionMode == ConversionMode::ConvertAndScale ||
|
||||||
} else if constexpr (ModeHasScales) {
|
KernelConversionMode ==
|
||||||
auto scale_k = (K + args.group_size - 1) / args.group_size;
|
ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
|
||||||
return {tma_load_a, tma_load_b, tma_load_scale,
|
return {tma_load_a, tma_load_b, tma_load_scale,
|
||||||
tma_load_zero, scale_k, args.group_size};
|
tma_load_zero, scale_k, group_size};
|
||||||
} else {
|
} else {
|
||||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
|
||||||
"Conversion mode not handled in to_underlying_arguments.");
|
"Conversion mode not handled in to_underlying_arguments.");
|
||||||
|
|||||||
@@ -24,13 +24,14 @@ MNK_SHAPES = [
|
|||||||
(1, 128, 128),
|
(1, 128, 128),
|
||||||
(1, 512, 1024),
|
(1, 512, 1024),
|
||||||
(1, 4096, 4096),
|
(1, 4096, 4096),
|
||||||
|
(1, 8192, 28672),
|
||||||
(13, 8192, 4096),
|
(13, 8192, 4096),
|
||||||
(26, 4096, 8192),
|
(26, 4096, 8192),
|
||||||
(1, 4096, 4096),
|
(64, 4096, 4096),
|
||||||
|
(64, 8192, 28672),
|
||||||
(257, 128, 4096),
|
(257, 128, 4096),
|
||||||
(257, 4224, 4160),
|
(257, 4224, 4160),
|
||||||
(257, 4096, 4096),
|
(257, 4096, 4096),
|
||||||
(64, 4096, 4096),
|
|
||||||
(1024, 4096, 8192),
|
(1024, 4096, 8192),
|
||||||
(1024, 8192, 4096),
|
(1024, 8192, 4096),
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user