[BugFix][DP/EP] Fix CUTLASS MLA hang under load (#26026)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Lucas Wilkinson 2025-10-01 15:30:00 -04:00 committed by simon-mo
parent e4beabd2c8
commit ebce361c07

View File

@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_page_table( load_page_table(
blk_coord, blk_coord,
problem_shape, problem_shape,
params.mainloop, params.mainloop,
shared_storage.tensors, shared_storage.tensors,
pipeline_page_table, pipeline_pt_producer_state, pipeline_page_table, pipeline_pt_producer_state,
local_split_kv local_split_kv
); );
} }
} }
@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_cpasync( load_cpasync(
blk_coord, blk_coord,
@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
params.mainloop_params, params.mainloop_params,
shared_storage.tensors, shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv, local_split_kv,
/* must be shared pipe */ /* must be shared pipe */
pipeline_page_table, pipeline_pt_consumer_state pipeline_page_table, pipeline_pt_consumer_state
); );
@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_tma</* paged= */ true>( load_tma</* paged= */ true>(
blk_coord, blk_coord,
@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
shared_storage.tensors, shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv local_split_kv
); );
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
} }
@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_tma<false>( load_tma<false>(
blk_coord, blk_coord,
@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
shared_storage.tensors, shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv local_split_kv
); );
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
} }
@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
mma(blk_coord, mma(blk_coord,
problem_shape, problem_shape,
@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_producer_state, pipeline_mma_s, pipeline_mma_s_producer_state,
pipeline_p_mma, pipeline_p_mma_consumer_state, pipeline_p_mma, pipeline_p_mma_consumer_state,
pipeline_mma_o, pipeline_mma_o_producer_state, pipeline_mma_o, pipeline_mma_o_producer_state,
local_split_kv local_split_kv
); );
} }
} }
@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto split_kv = params.split_kv; auto split_kv = params.split_kv;
auto local_split_kv = split_kv; auto local_split_kv = split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
compute( compute(
blk_coord, blk_coord,
@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_consumer_state, pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state, pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state, pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv local_split_kv
); );
} }
@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
cutlass::arch::NamedBarrier( cutlass::arch::NamedBarrier(
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
kNamedBarrierEpilogue kNamedBarrierEpilogue
).arrive(); ).arrive_and_wait();
return; return;
} }