[Kernels][FI] Skip trtllm attention when num_kv_heads=1 (#30842)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Ye (Charlotte) Qi 2025-12-17 01:54:21 -08:00 committed by GitHub
parent 4c054d89aa
commit a100152288
2 changed files with 56 additions and 1 deletion

View File

@@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
         f"{torch.max(torch.abs(output - output_trtllm))}",
     )
+
+
+def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
+    """Test that TRTLLM attention correctly rejects num_kv_heads=1.
+
+    When num_kv_heads=1 (MQA), the KV cache strides become degenerate
+    (stride_heads == stride_batch), which causes CUDA's cuTensorMapEncodeTiled
+    to fail because TMA descriptors cannot handle degenerate 4D tensors with
+    singleton dimensions.
+
+    This test verifies that can_use_trtllm_attention returns False for
+    num_kv_heads=1 configurations.
+    """
+    from vllm.utils.flashinfer import can_use_trtllm_attention
+
+    # num_kv_heads=1 should be rejected
+    assert not can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1), (
+        "can_use_trtllm_attention should return False for num_kv_heads=1"
+    )
+    assert not can_use_trtllm_attention(num_qo_heads=32, num_kv_heads=1), (
+        "can_use_trtllm_attention should return False for num_kv_heads=1"
+    )
+
+    # num_kv_heads > 1 should be accepted (if the platform supports it).
+    # Note: this may return False on non-Blackwell platforms, which is fine.
+    result_kv8 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
+    result_kv1 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)
+
+    # Regardless of platform support, num_kv_heads=1 must never be accepted
+    # when num_kv_heads=8 is.
+    if result_kv8:
+        assert not result_kv1, (
+            "If TRTLLM is supported for num_kv_heads=8, "
+            "it must be rejected for num_kv_heads=1"
+        )
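
Illustration (not part of this diff): the degenerate-stride condition the docstring describes can be reproduced directly from contiguous tensor strides, assuming for illustration a [batch, num_kv_heads, seq_len, head_dim] KV layout (the actual vLLM paged KV cache layout differs).

import torch

# Contiguous KV tensors laid out as [batch, num_kv_heads, seq_len, head_dim].
kv_mqa = torch.empty(4, 1, 128, 64)   # num_kv_heads = 1 (MQA)
kv_gqa = torch.empty(4, 8, 128, 64)   # num_kv_heads = 8 (GQA)

# With num_kv_heads = 1 the batch and head strides collapse to the same value,
# which is the degenerate case the TMA descriptor builder cannot encode.
assert kv_mqa.stride(0) == kv_mqa.stride(1) == 128 * 64

# With num_kv_heads > 1 the strides stay distinct.
assert kv_gqa.stride(0) == 8 * kv_gqa.stride(1)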

View File

@@ -305,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
     if force_use_trtllm_attention() is False:
         return False
     has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
+    # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
+    # When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
+    # stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
+    # TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
+    # See: https://fburl.com/352mrydz
+    if has_trtllm and num_kv_heads == 1:
+        logger.warning_once(
+            "TRTLLM attention does not support num_kv_heads=1. "
+            "This configuration causes TMA descriptor building to fail due to "
+            "degenerate tensor strides. Falling back to FlashInfer attention."
+        )
+    return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)


 def use_trtllm_attention(
@@ -355,6 +366,15 @@ def use_trtllm_attention(
             )
         return False

+    # num_kv_heads=1 is not supported
+    if num_kv_heads == 1:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention does not support num_kv_heads=1, "
+                "but --attention-config.use_trtllm_attention is set to 1"
+            )
+        return False
+
     if has_spec and not is_prefill:
         # Speculative decoding requires TRTLLM attention for decodes
         logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
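
Illustration (not part of this diff): a caller might gate backend selection on the stricter capability check roughly as follows. Only can_use_trtllm_attention is real vLLM code; select_attention_backend and the backend name strings are hypothetical.

from vllm.utils.flashinfer import can_use_trtllm_attention

def select_attention_backend(num_qo_heads: int, num_kv_heads: int) -> str:
    # Hypothetical gating helper: prefer TRTLLM only when the capability check
    # passes; after this change, num_kv_heads=1 always falls through to the
    # FlashInfer path.
    if can_use_trtllm_attention(num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads):
        return "trtllm"
    return "flashinfer"

# MQA (num_kv_heads=1) never selects TRTLLM, regardless of platform support.
assert select_attention_backend(64, 1) == "flashinfer"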