diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a0e099a191109..91b748b54c28f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1303,7 +1303,7 @@ class EngineArgs: # Skip this check if we are running on a non-GPU platform, # or if the device capability is not available # (e.g. in a Ray actor without GPUs). - from vllm.platforms import CpuArchEnum, current_platform + from vllm.platforms import current_platform if (current_platform.is_cuda() and current_platform.get_device_capability() and current_platform.get_device_capability().major < 8): @@ -1445,14 +1445,10 @@ class EngineArgs: _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - # Non-[CUDA, TPU, x86 CPU] may be supported on V1, - # but off by default for now. - v0_hardware = not any( - (current_platform.is_cuda_alike(), current_platform.is_tpu(), - (current_platform.is_cpu() - and current_platform.get_cpu_architecture() == CpuArchEnum.X86))) - if v0_hardware and _warn_or_fallback( # noqa: SIM103 - current_platform.device_name): + # The platform may be supported on V1, but off by default for now. + if not current_platform.default_v1( # noqa: SIM103 + model_config=model_config) and _warn_or_fallback( + current_platform.device_name): return False ############################################################# diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 1dfd394db608d..106bce162003f 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -269,3 +269,11 @@ class CpuPlatform(Platform): model configuration. """ return True + + @classmethod + def default_v1(cls, model_config) -> bool: + """Returns whether the current platform can use v1 by default for the + supplied model configuration. + """ + return cls.supports_v1( + model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f91f222b25e58..3ff173dcd8c85 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -479,6 +479,13 @@ class Platform: """ return False + @classmethod + def default_v1(cls, model_config: ModelConfig) -> bool: + """ + Returns whether the current platform supports v1 by default. + """ + return cls.supports_v1(model_config) + @classmethod def use_custom_allreduce(cls) -> bool: """