This commit is contained in:
Cody Yu 2024-08-27 00:16:31 -07:00 committed by GitHub
parent 64cc644425
commit 9606c7197d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -113,8 +113,7 @@ class FlashInferState(AttentionState):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
use_tensor_cores = num_qo_heads // num_kv_heads > 4
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
@ -172,8 +171,7 @@ class FlashInferState(AttentionState):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
use_tensor_cores = num_qo_heads // num_kv_heads > 4
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,