diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index f948157c2b575..1115fc606b055 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
+        self.sm_scale = self.global_hyperparameters.sm_scale
+        self.window_left = self.global_hyperparameters.window_left
+        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
+        self.has_sinks = self.global_hyperparameters.has_sinks
 
         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
 
         # Check if any layer uses sinks (requires TRTLLM attention)
-        has_sinks = self.global_hyperparameters.has_sinks
-
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
-                                                  has_sinks=has_sinks)
+                                                  has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
-                                                 has_sinks=has_sinks)
+                                                 has_sinks=self.has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.head_dim,
                     self.page_size,
                     causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                         self.head_dim,
                         self.page_size,
                         causal=True,
-                        sm_scale=self.global_hyperparameters.sm_scale,
-                        window_left=self.global_hyperparameters.window_left,
-                        logits_soft_cap=self.global_hyperparameters.
-                        logits_soft_cap,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
                         q_data_type=self.q_data_type,
                         kv_data_type=self.kv_cache_dtype,
                     )
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.page_size,
                     # Disable flashinfer's pos encoding and use vllm's rope.
                     pos_encoding_mode="NONE",
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
+                    sm_scale=self.sm_scale,
+                    window_left=self.window_left,
+                    logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
                 )
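
The diff is a small hot-path refactor: the hyperparameters inferred once at construction time (`sm_scale`, `window_left`, `logits_soft_cap`, `has_sinks`) are cached as flat instance attributes, so the per-step metadata build reads `self.sm_scale` instead of repeating the chained lookup `self.global_hyperparameters.sm_scale` at every call site. Below is a minimal sketch of the same pattern under simplified assumptions: `GlobalHyperparameters` and `MetadataBuilder` here are hypothetical stand-ins for vLLM's actual types, not the real classes.

```python
from dataclasses import dataclass


@dataclass
class GlobalHyperparameters:
    """Stand-in for the object returned by infer_global_hyperparameters()."""
    sm_scale: float
    window_left: int
    logits_soft_cap: float
    has_sinks: bool


class MetadataBuilder:
    """Simplified stand-in for FlashInferMetadataBuilder."""

    def __init__(self, params: GlobalHyperparameters) -> None:
        self.global_hyperparameters = params
        # Cache the fields read on every build() call as flat attributes,
        # mirroring the diff: one attribute lookup instead of a chained one.
        self.sm_scale = params.sm_scale
        self.window_left = params.window_left
        self.logits_soft_cap = params.logits_soft_cap
        self.has_sinks = params.has_sinks

    def build(self) -> dict:
        # Hot path: before the refactor, each value was fetched via
        # self.global_hyperparameters.<field> at every call site.
        return {
            "sm_scale": self.sm_scale,
            "window_left": self.window_left,
            "logits_soft_cap": self.logits_soft_cap,
            "has_sinks": self.has_sinks,
        }


builder = MetadataBuilder(
    GlobalHyperparameters(sm_scale=0.125, window_left=-1,
                          logits_soft_cap=0.0, has_sinks=False))
print(builder.build())
```

A side benefit visible in the diff itself: the shorter attribute names fit on one line at the wrapped call sites, removing the awkward `logits_soft_cap=self.global_hyperparameters.` / `logits_soft_cap,` line breaks the formatter previously produced.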