[Kernel] Add env variable to force flashinfer backend to enable tensor cores (#9497)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
  parent d11bf435a0
  commit 0c9a5258f9
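For context on what the heuristic in this diff computes: flashinfer's decode kernel benefits from tensor cores when many query heads share each KV head (grouped-query attention), and the existing code enables them only when that ratio exceeds 4. A minimal sketch of the old and new decision, with illustrative head counts that are assumptions rather than values from this commit:

import os

# Illustrative GQA head counts (assumed for this example).
num_qo_heads = 32  # query heads
num_kv_heads = 8   # key/value heads

# Old behavior: heuristic only. 32 // 8 == 4, which is not > 4,
# so tensor cores stay disabled for this configuration.
use_tensor_cores = num_qo_heads // num_kv_heads > 4

# New behavior: the env variable short-circuits the heuristic.
force = bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0")))
use_tensor_cores = force or (num_qo_heads // num_kv_heads > 4)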
vllm/attention/backends/flashinfer.py
@@ -17,6 +17,7 @@ except ImportError:
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
@@ -124,7 +125,8 @@ class FlashInferState(AttentionState):
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads > 4
+            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
+                num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 "NHD",
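The computed flag is handed to flashinfer's decode wrapper as a constructor argument. A sketch of that hand-off, assuming flashinfer is installed and a CUDA device is available; the workspace size is an illustrative choice, not a value from this diff:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

# Illustrative workspace allocation (size assumed for this example).
workspace_buffer = torch.empty(128 * 1024 * 1024,
                               dtype=torch.uint8,
                               device="cuda")

# "NHD" matches the KV-cache layout string in the hunk above; the
# use_tensor_cores kwarg receives the env-override-or-heuristic result.
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=True)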
@@ -183,7 +185,8 @@ class FlashInferState(AttentionState):
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads > 4
+            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
+                num_qo_heads // num_kv_heads > 4)
             self._graph_decode_wrapper = \
                 CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                     self._graph_decode_workspace_buffer, _indptr_buffer,
vllm/envs.py
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: bool = False
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
+    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
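The declarations under if TYPE_CHECKING: exist so type checkers see each variable and its type; at runtime the value comes from the environment_variables dict of lambdas in the next hunk, re-evaluated on each access. A self-contained sketch of that pattern, written as a standalone module rather than the actual vllm/envs.py source:

import os
from typing import TYPE_CHECKING, Any, Callable, Dict

if TYPE_CHECKING:
    # Seen by type checkers only; runtime access goes through __getattr__.
    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False

environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
}

def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562): re-read the environment each
    # time the attribute is accessed, so the variable can be set any
    # time before it is first consulted.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")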
@@ -286,6 +287,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_USE_FLASHINFER_SAMPLER":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
 
+    # If set, vllm will force flashinfer to use tensor cores;
+    # otherwise will use heuristic based on model architecture.
+    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
+    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
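Usage sketch: because the lambda reads os.getenv lazily, the variable only needs to be set before the attention backend is constructed. The model name below is an illustrative assumption, not part of this commit:

import os

# Select the flashinfer backend and force tensor cores before importing vllm.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["VLLM_FLASHINFER_FORCE_TENSOR_CORES"] = "1"

from vllm import LLM  # import after the env vars are set

llm = LLM(model="facebook/opt-125m")  # illustrative model choice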