diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ce7a7198dc400..a8d76b79ff204 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -113,8 +113,7 @@ class FlashInferState(AttentionState):
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-                (1, 2, 4, 8)
+            use_tensor_cores = num_qo_heads // num_kv_heads > 4
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 "NHD",
@@ -172,8 +171,7 @@ class FlashInferState(AttentionState):
             self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-            (1, 2, 4, 8)
+        use_tensor_cores = num_qo_heads // num_kv_heads > 4
         self._graph_decode_wrapper = \
             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
             self._graph_decode_workspace_buffer, _indptr_buffer,
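
For context, a minimal sketch of what the heuristic change does (the head counts below are illustrative examples, not values from any particular model): the old check enabled the tensor-core decode wrapper only for query/KV head ratios outside {1, 2, 4, 8}, while the new check enables it for any ratio greater than 4, so common GQA ratios such as 8 now take the tensor-core path.

```python
# Illustrative comparison of the old and new use_tensor_cores heuristics
# from the diff above. Head counts are made-up examples.

def old_heuristic(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Old rule: only "unusual" group sizes used the tensor-core wrapper.
    return (num_qo_heads // num_kv_heads) not in (1, 2, 4, 8)

def new_heuristic(num_qo_heads: int, num_kv_heads: int) -> bool:
    # New rule: any group size above 4 uses the tensor-core wrapper.
    return num_qo_heads // num_kv_heads > 4

for qo, kv in [(32, 32), (32, 8), (32, 4), (40, 8)]:
    ratio = qo // kv
    print(f"ratio={ratio}: old={old_heuristic(qo, kv)}, new={new_heuristic(qo, kv)}")
# ratio=1: old=False, new=False
# ratio=4: old=False, new=False
# ratio=8: old=False, new=True   <- ratio 8 now enables tensor cores
# ratio=5: old=True,  new=True
```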