mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 04:54:59 +08:00
refactor: abstract graph mode support into platform interface (#25161)
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
parent
4cf71cc88a
commit
b6f01bd9a7
@ -503,7 +503,7 @@ class VllmConfig:
|
|||||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
self.compilation_config.custom_ops.append("+rms_norm")
|
self.compilation_config.custom_ops.append("+rms_norm")
|
||||||
|
|
||||||
if current_platform.is_cuda_alike() or current_platform.is_xpu():
|
if current_platform.support_static_graph_mode():
|
||||||
# if cudagraph_mode is not explicitly set by users, set default
|
# if cudagraph_mode is not explicitly set by users, set default
|
||||||
# value
|
# value
|
||||||
if self.compilation_config.cudagraph_mode is None:
|
if self.compilation_config.cudagraph_mode is None:
|
||||||
|
|||||||
@ -498,6 +498,10 @@ class CudaPlatformBase(Platform):
|
|||||||
def support_hybrid_kv_cache(cls) -> bool:
|
def support_hybrid_kv_cache(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def support_static_graph_mode(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
# NVML utils
|
# NVML utils
|
||||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||||
|
|||||||
@ -587,6 +587,13 @@ class Platform:
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def support_static_graph_mode(cls) -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether static graph mode is supported by the current platform.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def use_sync_weight_loader(cls) -> bool:
|
def use_sync_weight_loader(cls) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -477,3 +477,7 @@ class RocmPlatform(Platform):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def support_hybrid_kv_cache(cls) -> bool:
|
def support_hybrid_kv_cache(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def support_static_graph_mode(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|||||||
@ -113,12 +113,9 @@ class XPUPlatform(Platform):
|
|||||||
# lazy import to avoid circular import
|
# lazy import to avoid circular import
|
||||||
from vllm.config import CompilationLevel, CUDAGraphMode
|
from vllm.config import CompilationLevel, CUDAGraphMode
|
||||||
compilation_config = vllm_config.compilation_config
|
compilation_config = vllm_config.compilation_config
|
||||||
if compilation_config.cudagraph_mode is None or \
|
|
||||||
compilation_config.cudagraph_mode.max_cudagraph_mode() \
|
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
|
||||||
!= CUDAGraphMode.NONE:
|
"CUDA graph mode should be NONE on XPU"
|
||||||
logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
|
|
||||||
"cudagraphs. Fallback to cudagraph_mode=NONE")
|
|
||||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
|
||||||
|
|
||||||
if vllm_config.lora_config is not None:
|
if vllm_config.lora_config is not None:
|
||||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||||
@ -169,6 +166,10 @@ class XPUPlatform(Platform):
|
|||||||
def support_hybrid_kv_cache(cls) -> bool:
|
def support_hybrid_kv_cache(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def support_static_graph_mode(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_pin_memory_available(cls):
|
def is_pin_memory_available(cls):
|
||||||
return True
|
return True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user