diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 0ebfe1c22269a..cb15952f0d2de 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -8,6 +8,11 @@ to avoid certain eager import breakage.""" from __future__ import annotations import importlib.metadata +import sys + +from vllm.logger import init_logger + +logger = init_logger(__name__) def main(): @@ -29,6 +34,22 @@ def main(): cli_env_setup() + # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default + if len(sys.argv) > 1 and sys.argv[1] == "bench": + logger.debug( + "Bench command detected, must ensure current platform is not " + "UnspecifiedPlatform to avoid device type inference error" + ) + from vllm import platforms + + if platforms.current_platform.is_unspecified(): + from vllm.platforms.cpu import CpuPlatform + + platforms.current_platform = CpuPlatform() + logger.info( + "Unspecified platform detected, switching to CPU Platform instead." + ) + parser = FlexibleArgumentParser( description="vLLM CLI", epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 962e1323b7215..d1708ad5c7517 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -261,4 +261,14 @@ def __getattr__(name: str): raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") +def __setattr__(name: str, value): + if name == "current_platform": + global _current_platform + _current_platform = value + elif name in globals(): + globals()[name] = value + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + __all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"] diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 59bc9173958c4..6dc49f99ac2ad 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -141,6 +141,9 @@ class Platform: def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT + def is_unspecified(self) -> bool: + return self._enum == PlatformEnum.UNSPECIFIED + def get_max_output_tokens(self, prompt_len: int) -> int: return sys.maxsize