diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index c56e721dff8cf..44f95c7686863 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -36,6 +36,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               get_per_layer_parameters,
                                               infer_global_hyperparameters,
                                               split_decodes_and_prefills)
+# yapf: enable
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
@@ -541,12 +542,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if cache_dtype.startswith("fp8") and enable_fusion:
             q_dtype = kv_cache_dtype
 
-        prefill_use_trtllm = use_trtllm_attention(
-            num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
-            cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
-        decode_use_trtllm = use_trtllm_attention(
-            num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
-            cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)
+        prefill_use_trtllm = use_trtllm_attention(num_qo_heads,
+                                                  num_kv_heads,
+                                                  num_prefill_tokens,
+                                                  max_seq_len,
+                                                  cache_dtype,
+                                                  q_dtype,
+                                                  is_prefill=True,
+                                                  has_sinks=has_sinks)
+        decode_use_trtllm = use_trtllm_attention(num_qo_heads,
+                                                 num_kv_heads,
+                                                 num_decode_tokens,
+                                                 max_seq_len,
+                                                 cache_dtype,
+                                                 q_dtype,
+                                                 is_prefill=False,
+                                                 has_sinks=has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -654,19 +665,18 @@ class FlashInferImpl(AttentionImpl):
                 raise ValueError(
                     "Sinks must have the same number of heads as the number of "
                     f"heads in the layer. Expected {num_heads}, but got "
-                    f"{sinks.shape[0]}."
-                )
+                    f"{sinks.shape[0]}.")
             self.sinks = sinks
 
-        self.support_trtllm_attn = (supports_trtllm_attention() and
-                                    num_heads % num_kv_heads == 0)
+        self.support_trtllm_attn = (supports_trtllm_attention()
+                                    and num_heads % num_kv_heads == 0)
         self.bmm1_scale: Optional[float] = None
         self.bmm2_scale: Optional[float] = None
 
     def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                      group_shape: GroupShape):
-        supported_quant_type = (dtype == FP8_DTYPE and static and
-                                group_shape == GroupShape.PER_TENSOR)
+        supported_quant_type = (dtype == FP8_DTYPE and static
+                                and group_shape == GroupShape.PER_TENSOR)
         return (self.support_trtllm_attn
                 and self.kv_cache_dtype.startswith("fp8")
                 and supported_quant_type)
@@ -731,7 +741,8 @@ class FlashInferImpl(AttentionImpl):
             # Insert FP8 quant for query
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
                 layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))