Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-09 18:23:12 +08:00)
[Kernels][FI] Skip trtllm attention when num_kv_heads=1 (#30842)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent 4c054d89aa
commit a100152288
@@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), (
         f"{torch.max(torch.abs(output - output_trtllm))}"
     )
+
+
+def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
+    """Test that TRTLLM attention correctly rejects num_kv_heads=1.
+
+    When num_kv_heads=1 (MQA), the KV cache strides become degenerate
+    (stride_heads == stride_batch), which causes CUDA's cuTensorMapEncodeTiled
+    to fail because TMA descriptors cannot handle degenerate 4D tensors with
+    singleton dimensions.
+
+    This test verifies that can_use_trtllm_attention returns False for
+    num_kv_heads=1 configurations.
+    """
+    from vllm.utils.flashinfer import can_use_trtllm_attention
+
+    # num_kv_heads=1 should be rejected
+    assert not can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1), (
+        "can_use_trtllm_attention should return False for num_kv_heads=1"
+    )
+    assert not can_use_trtllm_attention(num_qo_heads=32, num_kv_heads=1), (
+        "can_use_trtllm_attention should return False for num_kv_heads=1"
+    )
+
+    # num_kv_heads > 1 should be accepted (if platform supports it)
+    # Note: This may return False on non-Blackwell platforms, which is fine
+    result_kv8 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
+    result_kv1 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)
+
+    # Even if platform doesn't support TRTLLM, num_kv_heads=1 should never
+    # return True when num_kv_heads > 1 returns True
+    if result_kv8:
+        assert not result_kv1, (
+            "If TRTLLM is supported for num_kv_heads=8, "
+            "it must be rejected for num_kv_heads=1"
+        )
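
The degenerate-stride behavior described in the docstring above can be reproduced without a GPU. A minimal sketch, assuming a contiguous (batch, num_kv_heads, seq_len, head_dim) layout purely for illustration (not necessarily vLLM's exact KV cache layout):

import torch

# Assumed illustrative layout: (batch, num_kv_heads, seq_len, head_dim).
batch, seq_len, head_dim = 2, 128, 64

kv_gqa = torch.empty(batch, 8, seq_len, head_dim)  # num_kv_heads=8
kv_mqa = torch.empty(batch, 1, seq_len, head_dim)  # num_kv_heads=1

# With num_kv_heads > 1, the batch and head strides differ:
assert kv_gqa.stride(0) != kv_gqa.stride(1)  # 65536 vs 8192

# With num_kv_heads == 1 they collapse (stride_batch == stride_heads),
# which is the degenerate 4D case that cuTensorMapEncodeTiled rejects
# when building a TMA descriptor:
assert kv_mqa.stride(0) == kv_mqa.stride(1)  # both 8192
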
@@ -305,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
     if force_use_trtllm_attention() is False:
         return False
     has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
+    # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
+    # When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
+    # stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
+    # TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
+    # See: https://fburl.com/352mrydz
+    if has_trtllm and num_kv_heads == 1:
+        logger.warning_once(
+            "TRTLLM attention does not support num_kv_heads=1. "
+            "This configuration causes TMA descriptor building to fail due to "
+            "degenerate tensor strides. Falling back to FlashInfer attention."
+        )
+    return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
 
 
 def use_trtllm_attention(
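
For reference, a caller-side sketch of how this predicate gates backend selection; the dispatch function and backend names below are illustrative assumptions, not vLLM's actual selection code:

from vllm.utils.flashinfer import can_use_trtllm_attention

# Hypothetical dispatch for illustration only.
def pick_attention_backend(num_qo_heads: int, num_kv_heads: int) -> str:
    if can_use_trtllm_attention(num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads):
        return "trtllm"
    # Always taken for num_kv_heads=1 after this change; the predicate
    # itself emits a one-time warning explaining the fallback.
    return "flashinfer"

print(pick_attention_backend(64, 8))  # "trtllm" on supported platforms
print(pick_attention_backend(64, 1))  # always "flashinfer"
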
@@ -355,6 +366,15 @@ def use_trtllm_attention(
         )
         return False
 
+    # num_kv_heads=1 is not supported
+    if num_kv_heads == 1:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention does not support num_kv_heads=1, "
+                "but --attention-config.use_trtllm_attention is set to 1"
+            )
+        return False
+
     if has_spec and not is_prefill:
         # Speculative decoding requires TRTLLM attention for decodes
         logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
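
To exercise the new test in isolation, selecting it by name avoids hard-coding the test file path (which the hunk header does not reveal). A small runner sketch, to be executed from a vLLM checkout:

# pytest.main returns an exit code; -k selects the new test by name.
import pytest

raise SystemExit(
    pytest.main(["-q", "-k", "test_trtllm_attention_rejects_num_kv_heads_1"])
)
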