From 11eddf02f0234f79435d747f2d3dce117ab39aa1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 27 Aug 2025 03:45:04 -0700
Subject: [PATCH] [FlashInfer] Cache hyper params in metadata builder (#23732)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/attention/backends/flashinfer.py | 30 ++++++++++++------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index f948157c2b575..1115fc606b055 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
+        self.sm_scale = self.global_hyperparameters.sm_scale
+        self.window_left = self.global_hyperparameters.window_left
+        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
+        self.has_sinks = self.global_hyperparameters.has_sinks
 
         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
 
         # Check if any layer uses sinks (requires TRTLLM attention)
-        has_sinks = self.global_hyperparameters.has_sinks
-
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
-                                                  has_sinks=has_sinks)
+                                                  has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
-                                                 has_sinks=has_sinks)
+                                                 has_sinks=self.has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.head_dim,
                     self.page_size,
                     causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                         self.head_dim,
                         self.page_size,
                         causal=True,
-                        sm_scale=self.global_hyperparameters.sm_scale,
-                        window_left=self.global_hyperparameters.window_left,
-                        logits_soft_cap=self.global_hyperparameters.
-                        logits_soft_cap,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
                         q_data_type=self.q_data_type,
                         kv_data_type=self.kv_cache_dtype,
                     )
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.page_size,
                     # Disable flashinfer's pos encoding and use vllm's rope.
                     pos_encoding_mode="NONE",
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
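
Note (not part of the patch): the change follows a simple caching pattern, read the shared hyperparameters once when the metadata builder is constructed, then reference the cached scalar attributes in the per-step build path instead of chasing the global_hyperparameters object each time. The sketch below illustrates that pattern only; GlobalHyperparameters and MetadataBuilderSketch are illustrative stand-ins, and just the field names mirror the diff.

# Minimal, self-contained sketch of the caching pattern applied above.
# The class and field names here are hypothetical; only sm_scale,
# window_left, logits_soft_cap, and has_sinks correspond to the diff.
from dataclasses import dataclass


@dataclass
class GlobalHyperparameters:
    sm_scale: float
    window_left: int
    logits_soft_cap: float
    has_sinks: bool


class MetadataBuilderSketch:
    def __init__(self, params: GlobalHyperparameters):
        self.global_hyperparameters = params
        # Cache the frequently used scalars once, at construction time.
        self.sm_scale = params.sm_scale
        self.window_left = params.window_left
        self.logits_soft_cap = params.logits_soft_cap
        self.has_sinks = params.has_sinks

    def build(self) -> dict:
        # Hot path: touch only the cached attributes, never the
        # global_hyperparameters object.
        return {
            "sm_scale": self.sm_scale,
            "window_left": self.window_left,
            "logits_soft_cap": self.logits_soft_cap,
            "has_sinks": self.has_sinks,
        }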