From 0c9a5258f905ff3b03019f9134914ab90dbdac01 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sat, 19 Oct 2024 02:55:48 +0200
Subject: [PATCH] [Kernel] Add env variable to force flashinfer backend to
 enable tensor cores (#9497)

Signed-off-by: Thomas Parnell
Co-authored-by: Chih-Chieh Yang
Co-authored-by: Cody Yu
---
 vllm/attention/backends/flashinfer.py | 7 +++++--
 vllm/envs.py                          | 6 ++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index dd9a0fb9d94df..1dd2a21fdb51a 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -17,6 +17,7 @@ except ImportError:
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
@@ -124,7 +125,8 @@ class FlashInferState(AttentionState):
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads > 4
+            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
+                num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(), "NHD",
@@ -183,7 +185,8 @@ class FlashInferState(AttentionState):
             self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = num_qo_heads // num_kv_heads > 4
+        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
+            num_qo_heads // num_kv_heads > 4)
         self._graph_decode_wrapper = \
             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
             self._graph_decode_workspace_buffer, _indptr_buffer,
diff --git a/vllm/envs.py b/vllm/envs.py
index 2396e87e20c39..385db82d89249 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: bool = False
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
+    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -286,6 +287,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_USE_FLASHINFER_SAMPLER":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
 
+    # If set, vllm will force flashinfer to use tensor cores;
+    # otherwise will use heuristic based on model architecture.
+    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
+    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
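
Usage note (not part of the patch): the variable has to be set in the environment of
the vLLM process, typically alongside VLLM_ATTENTION_BACKEND=FLASHINFER, and is parsed
as bool(int(...)), so "1" enables it and the default "0" keeps the heuristic. The
sketch below is illustrative only; should_use_tensor_cores is a hypothetical helper,
not a function in the patch, but it mirrors how the new flag combines with the
existing query-head / KV-head ratio heuristic.

    # Illustrative sketch, not part of the patch: how the new environment
    # variable combines with the existing heads-ratio heuristic.
    # should_use_tensor_cores is a hypothetical helper name.
    import os

    def should_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
        # "1" forces tensor cores; the default "0" falls back to the heuristic.
        force = bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0")))
        return force or (num_qo_heads // num_kv_heads > 4)

    # Example: 32 query heads over 8 KV heads gives a ratio of exactly 4, so the
    # heuristic alone would not pick the tensor-core decode wrapper; setting the
    # variable overrides that decision.
    print(should_use_tensor_cores(32, 8))                    # False by default
    os.environ["VLLM_FLASHINFER_FORCE_TENSOR_CORES"] = "1"
    print(should_use_tensor_cores(32, 8))                    # True when forced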