Fix trtllm-gen attention env and add attention sink (#22378)

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Lain <fusiyuan2000@hotmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Lain 2025-08-06 18:07:41 -07:00 committed by GitHub
parent 5c7cc33f4d
commit 9a3835aaa9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 21 additions and 28 deletions

View File

@ -152,8 +152,7 @@ if TYPE_CHECKING:
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
@ -1043,13 +1042,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CUDNN_PREFILL":
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
# If set to 1, use the TRTLLM Context Attention backend in flashinfer.
"VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),
# If set to 1, use the TRTLLM attention backend in flashinfer.
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.

View File

@ -70,9 +70,8 @@ class OAIAttention(nn.Module):
tp_size = get_tensor_model_parallel_world_size()
attention_sink_dtype = (
torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
else torch.bfloat16)
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=attention_sink_dtype,

View File

@ -159,7 +159,7 @@ def use_trtllm_attention(
# Check if the dimensions are supported by TRTLLM decode attention
if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
or num_qo_heads % num_kv_heads != 0):
return False
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
@ -169,10 +169,10 @@ def use_trtllm_attention(
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
no_use_trtllm = (env_value == "0")
if not no_use_trtllm:
use_trtllm = (env_value == "1")
if use_trtllm:
logger.info_once("Using TRTLLM attention.")
return not no_use_trtllm
return use_trtllm
else:
# Environment variable not set - use auto-detection
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072

View File

@ -215,6 +215,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
# TODO: discard this for trtllm-gen backend
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
@ -523,16 +524,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim = self.kv_cache_spec.head_size
# currently prefill trtllm attention does not support fp8 kv cache
# trtllm may not support sliding window
prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
and not cache_dtype.startswith("fp8")
and use_trtllm_attention(
prefill_use_trtllm = use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim))
decode_use_trtllm = (self.global_hyperparameters.window_left == -1
and use_trtllm_attention(
num_qo_heads, num_kv_heads, head_dim)
decode_use_trtllm = use_trtllm_attention(
num_decode_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim))
num_qo_heads, num_kv_heads, head_dim)
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
@ -793,6 +790,8 @@ class FlashInferImpl(AttentionImpl):
batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=window_left,
sinks=self.sinks,
out=output[num_decode_tokens:],
)
@ -839,6 +838,8 @@ class FlashInferImpl(AttentionImpl):
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
window_left=window_left,
sinks=self.sinks,
out=output[:num_decode_tokens],
)
return output_padded

View File

@ -254,8 +254,7 @@ def get_kv_cache_layout():
# Override with format specified by the user.
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
if envs.VLLM_USE_TRTLLM_ATTENTION:
cache_layout = "HND"
else:
cache_layout = get_kv_connector_cache_layout()
@ -333,8 +332,7 @@ def infer_global_hyperparameters(
global_params = param_sets[0]
# trtllm attention doesn't need global hyper params so disable the check
if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
if not envs.VLLM_USE_TRTLLM_ATTENTION:
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(