From 9a3835aaa9006c0d53628f278319642774d88fbe Mon Sep 17 00:00:00 2001 From: Lain Date: Wed, 6 Aug 2025 18:07:41 -0700 Subject: [PATCH] Fix trtllm-gen attention env and add attention sink (#22378) Signed-off-by: Siyuan Fu Signed-off-by: Lain Signed-off-by: Yongye Zhu Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Michael Goin Co-authored-by: Yongye Zhu --- vllm/envs.py | 13 ++++--------- vllm/model_executor/models/gpt_oss.py | 5 ++--- vllm/utils/flashinfer.py | 8 ++++---- vllm/v1/attention/backends/flashinfer.py | 17 +++++++++-------- vllm/v1/attention/backends/utils.py | 6 ++---- 5 files changed, 21 insertions(+), 28 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 8a3eb8e509f7d..d9ebf59c1ae16 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -152,8 +152,7 @@ if TYPE_CHECKING: VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False - VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False - VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False + VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False @@ -1043,13 +1042,9 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_CUDNN_PREFILL": lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), - # If set to 1, use the TRTLLM Context Attention backend in flashinfer. - "VLLM_USE_TRTLLM_CONTEXT_ATTENTION": - lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))), - - # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. - "VLLM_USE_TRTLLM_DECODE_ATTENTION": - lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))), + # If set to 1, use the TRTLLM attention backend in flashinfer. + "VLLM_USE_TRTLLM_ATTENTION": + lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index c37c4e9610054..feb323a04524b 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -70,9 +70,8 @@ class OAIAttention(nn.Module): tp_size = get_tensor_model_parallel_world_size() - attention_sink_dtype = ( - torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION - or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16) + attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION + else torch.bfloat16) self.sinks = torch.nn.Parameter( torch.empty(config.num_attention_heads // tp_size, dtype=attention_sink_dtype, diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index cce1aefaf9b02..32c52612ca16f 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -159,7 +159,7 @@ def use_trtllm_attention( # Check if the dimensions are supported by TRTLLM decode attention if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None - or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): + or num_qo_heads % num_kv_heads != 0): return False env_value = envs.VLLM_USE_TRTLLM_ATTENTION @@ -169,10 +169,10 @@ def use_trtllm_attention( # Making the conditional check for zero because # the path is automatically enabled if the batch size condition # is satisfied. - no_use_trtllm = (env_value == "0") - if not no_use_trtllm: + use_trtllm = (env_value == "1") + if use_trtllm: logger.info_once("Using TRTLLM attention.") - return not no_use_trtllm + return use_trtllm else: # Environment variable not set - use auto-detection use_trtllm = (num_tokens <= 256 and max_seq_len < 131072 diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 061bd5f1d277a..1fcb190286329 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -215,6 +215,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers + # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) @@ -523,16 +524,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): head_dim = self.kv_cache_spec.head_size # currently prefill trtllm attention does not support fp8 kv cache - # trtllm may not support sliding window - prefill_use_trtllm = (self.global_hyperparameters.window_left == -1 - and not cache_dtype.startswith("fp8") - and use_trtllm_attention( + prefill_use_trtllm = use_trtllm_attention( num_prefill_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim)) - decode_use_trtllm = (self.global_hyperparameters.window_left == -1 - and use_trtllm_attention( + num_qo_heads, num_kv_heads, head_dim) + decode_use_trtllm = use_trtllm_attention( num_decode_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim)) + num_qo_heads, num_kv_heads, head_dim) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, @@ -793,6 +790,8 @@ class FlashInferImpl(AttentionImpl): batch_size=attn_metadata.num_prefills, cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, + window_left=window_left, + sinks=self.sinks, out=output[num_decode_tokens:], ) @@ -839,6 +838,8 @@ class FlashInferImpl(AttentionImpl): max_seq_len=attn_metadata.max_seq_len, bmm1_scale=layer._k_scale_float * self.scale, bmm2_scale=layer._v_scale_float, + window_left=window_left, + sinks=self.sinks, out=output[:num_decode_tokens], ) return output_padded diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f521d94331b5e..770c14572ff2b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -254,8 +254,7 @@ def get_kv_cache_layout(): # Override with format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT if cache_layout is None: - if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION - or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION): + if envs.VLLM_USE_TRTLLM_ATTENTION: cache_layout = "HND" else: cache_layout = get_kv_connector_cache_layout() @@ -333,8 +332,7 @@ def infer_global_hyperparameters( global_params = param_sets[0] # trtllm attention doesn't need global hyper params so disable the check - if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION - and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION): + if not envs.VLLM_USE_TRTLLM_ATTENTION: for params in param_sets: if params.window_left != global_params.window_left: raise ValueError(