From db77f9b3a2badf9ff8ab0199e57f963987dfa20c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 26 Sep 2025 10:24:13 -0700 Subject: [PATCH] potential hang fix Signed-off-by: Lucas Wilkinson --- .../kernel/sm100_fmha_mla_tma_warpspecialized.hpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index 2cbc2379579eb..321af8058f3f9 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -582,7 +582,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto problem_shape = params.problem_shape; auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + auto seqlen = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (seqlen == 0) continue; + get<1>(problem_shape) = seqlen; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } @@ -607,7 +609,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto problem_shape = params.problem_shape; auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + auto seqlen = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (seqlen == 0) continue; + get<1>(problem_shape) = seqlen; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } @@ -636,7 +640,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto problem_shape = params.problem_shape; auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + auto seqlen = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (seqlen == 0) continue; + get<1>(problem_shape) = seqlen; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; }