From 9606c7197df073e373ab9e716a62dd4c35398865 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Tue, 27 Aug 2024 00:16:31 -0700
Subject: [PATCH] Revert #7509 (#7887)

---
 vllm/attention/backends/flashinfer.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ce7a7198dc400..a8d76b79ff204 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -113,8 +113,7 @@ class FlashInferState(AttentionState):
                 self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-            (1, 2, 4, 8)
+        use_tensor_cores = num_qo_heads // num_kv_heads > 4
         self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
             self._get_workspace_buffer(),
             "NHD",
@@ -172,8 +171,7 @@ class FlashInferState(AttentionState):
                 self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-            (1, 2, 4, 8)
+        use_tensor_cores = num_qo_heads // num_kv_heads > 4
         self._graph_decode_wrapper = \
             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                 self._graph_decode_workspace_buffer, _indptr_buffer,
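
For context, the patch reverts the `not in (1, 2, 4, 8)` condition from #7509 and restores the simpler group-size heuristic for selecting FlashInfer's tensor-core decode kernel. Below is a minimal, hedged sketch of what the restored condition computes; the standalone helper name and the example head counts are illustrative and not part of vLLM's API.

    # Sketch only: vLLM computes this inline in FlashInferState, not via
    # a helper. The GQA group size is the number of query heads served by
    # each KV head; the restored heuristic enables the tensor-core decode
    # wrapper only when that group size exceeds 4.
    def use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
        return num_qo_heads // num_kv_heads > 4

    # Illustrative head counts (assumed, not taken from the patch):
    # 64 query heads over 8 KV heads gives group size 8 -> tensor cores on.
    assert use_tensor_cores(64, 8) is True
    # Plain multi-head attention (group size 1) -> tensor cores off.
    assert use_tensor_cores(32, 32) is False

The practical difference from the reverted logic: `not in (1, 2, 4, 8)` would have enabled tensor cores for odd group sizes such as 3, while the restored `> 4` check keeps the non-tensor-core path for all small group sizes.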