diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index 2cbc2379579eb..321af8058f3f9 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -582,7 +582,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto problem_shape = params.problem_shape; auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + auto seqlen = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (seqlen == 0) continue; + get<1>(problem_shape) = seqlen; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } @@ -607,7 +609,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto problem_shape = params.problem_shape; auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + auto seqlen = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (seqlen == 0) continue; + get<1>(problem_shape) = seqlen; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } @@ -636,7 +640,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto problem_shape = params.problem_shape; auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + auto seqlen = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (seqlen == 0) continue; + get<1>(problem_shape) = seqlen; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; }