diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index fbbc2e588c32..297d94dcc063 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -135,10 +135,10 @@ public: max_splits = min(16, max_splits); // TODO: This avoids a hang when the batch size larger than 1 and - // there is more than 4 kv_splits. + // there is more than 1 kv_splits. // Discuss with NVIDIA how this can be fixed. if (B > 1) { - max_splits = min(2, max_splits); + max_splits = min(1, max_splits); } // printf(" max_splits = %d\n", max_splits);