refactor: abstract graph mode support into platform interface (#25161)

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou 2025-09-22 18:22:29 +08:00 committed by GitHub
parent 4cf71cc88a
commit b6f01bd9a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 23 additions and 7 deletions

View File

@ -503,7 +503,7 @@ class VllmConfig:
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.is_cuda_alike() or current_platform.is_xpu():
if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
# value
if self.compilation_config.cudagraph_mode is None:

View File

@ -498,6 +498,10 @@ class CudaPlatformBase(Platform):
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return True
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@ -587,6 +587,13 @@ class Platform:
"""
return False
@classmethod
def support_static_graph_mode(cls) -> bool:
"""
Returns if the graph mode is supported by the current platform.
"""
return False
@classmethod
def use_sync_weight_loader(cls) -> bool:
"""

View File

@ -477,3 +477,7 @@ class RocmPlatform(Platform):
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return True

View File

@ -113,12 +113,9 @@ class XPUPlatform(Platform):
# lazy import to avoid circular import
from vllm.config import CompilationLevel, CUDAGraphMode
compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
"cudagraphs. Fallback to cudagraph_mode=NONE")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
"CUDA graph mode should be NONE on XPU"
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION
@ -169,6 +166,10 @@ class XPUPlatform(Platform):
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return False
@classmethod
def is_pin_memory_available(cls):
return True