From a100152288c8ec50336aea842f0b3d8e36624024 Mon Sep 17 00:00:00 2001
From: "Ye (Charlotte) Qi"
Date: Wed, 17 Dec 2025 01:54:21 -0800
Subject: [PATCH] [Kernels][FI] Skip trtllm attention when num_kv_heads=1 (#30842)

Signed-off-by: Ye (Charlotte) Qi
---
 .../test_flashinfer_trtllm_attention.py | 35 +++++++++++++++++++
 vllm/utils/flashinfer.py                | 22 +++++++++++-
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 06a7085a82ba0..220d827b9d5fa 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
         f"{torch.max(torch.abs(output - output_trtllm))}",
     )
+
+
+def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
+    """Test that TRTLLM attention correctly rejects num_kv_heads=1.
+
+    When num_kv_heads=1 (MQA), the KV cache strides become degenerate
+    (stride_heads == stride_batch), which causes CUDA's cuTensorMapEncodeTiled
+    to fail because TMA descriptors cannot handle degenerate 4D tensors with
+    singleton dimensions.
+
+    This test verifies that can_use_trtllm_attention returns False for
+    num_kv_heads=1 configurations.
+    """
+    from vllm.utils.flashinfer import can_use_trtllm_attention
+
+    # num_kv_heads=1 should be rejected
+    assert not can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1), (
+        "can_use_trtllm_attention should return False for num_kv_heads=1"
+    )
+    assert not can_use_trtllm_attention(num_qo_heads=32, num_kv_heads=1), (
+        "can_use_trtllm_attention should return False for num_kv_heads=1"
+    )
+
+    # num_kv_heads > 1 should be accepted (if the platform supports it)
+    # Note: this may return False on non-Blackwell platforms, which is fine
+    result_kv8 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
+    result_kv1 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)
+
+    # If TRTLLM is supported for num_kv_heads=8, the same configuration with
+    # num_kv_heads=1 must still be rejected
+    if result_kv8:
+        assert not result_kv1, (
+            "If TRTLLM is supported for num_kv_heads=8, "
+            "it must be rejected for num_kv_heads=1"
+        )
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 1c2710be3173b..6bbe02348eaf1 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -305,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
     if force_use_trtllm_attention() is False:
         return False
     has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
+    # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
+    # When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
+    # stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
+    # TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
+    # See: https://fburl.com/352mrydz
+    if has_trtllm and num_kv_heads == 1:
+        logger.warning_once(
+            "TRTLLM attention does not support num_kv_heads=1. "
+            "This configuration causes TMA descriptor building to fail due to "
+            "degenerate tensor strides. Falling back to FlashInfer attention."
+        )
+    return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
 
 
 def use_trtllm_attention(
@@ -355,6 +366,15 @@ def use_trtllm_attention(
         )
         return False
 
+    # num_kv_heads=1 is not supported
+    if num_kv_heads == 1:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention does not support num_kv_heads=1, "
+                "but --attention-config.use_trtllm_attention is set to 1"
+            )
+        return False
+
     if has_spec and not is_prefill:
         # Speculative decoding requires TRTLLM attention for decodes
         logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
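
For reviewers, a minimal standalone sketch (not part of the patch) of the stride degeneracy the new guard protects against. It assumes a hypothetical contiguous [batch, num_kv_heads, seq_len, head_dim] layout rather than vLLM's actual paged KV cache, so it only illustrates why stride_batch == stride_heads collapses when num_kv_heads=1:

```python
# Sketch only: a hypothetical contiguous KV layout, not vLLM's paged KV cache.
import torch


def kv_strides(batch: int, num_kv_heads: int, seq_len: int, head_dim: int):
    """Return (stride_batch, stride_heads) for a contiguous 4D KV tensor."""
    kv = torch.empty(batch, num_kv_heads, seq_len, head_dim)
    stride_batch, stride_heads, _, _ = kv.stride()
    return stride_batch, stride_heads


# GQA-style config: batch and head strides differ, so the 4D layout is
# non-degenerate and the two dimensions remain distinguishable.
print(kv_strides(batch=2, num_kv_heads=8, seq_len=128, head_dim=64))  # (65536, 8192)

# MQA config (num_kv_heads=1): stride_batch == stride_heads, the degenerate
# singleton-dimension case the patch describes cuTensorMapEncodeTiled rejecting.
print(kv_strides(batch=2, num_kv_heads=1, seq_len=128, head_dim=64))  # (8192, 8192)
```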