From fedb75fa2790403b90ec6dc926fef9c6c5ccb7a6 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Date: Wed, 17 Sep 2025 18:06:38 -0400 Subject: [PATCH] [Bugfix][B200] Fix `cutlass_mla` hang (#24966) Signed-off-by: Alexander Matveev Co-authored-by: Michael Goin --- csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index 95e32559cd540..fbbc2e588c326 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -133,6 +133,14 @@ public: // printf(" sm_count = %d\n", sm_count); int max_splits = ceil_div(K, 128); max_splits = min(16, max_splits); + + // TODO: This avoids a hang when the batch size is larger than 1 and + // there are more than 4 kv_splits. + // Discuss with NVIDIA how this can be fixed. + if (B > 1) { + max_splits = min(2, max_splits); + } + // printf(" max_splits = %d\n", max_splits); int sms_per_batch = max(1, sm_count / B); // printf(" sms_per_batch = %d\n", sms_per_batch);