mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:28:42 +08:00
[Bugfix] Fix vllm bench ... on CPU-only head nodes (#25283)
Signed-off-by: Aydin Abiar <aydin@anyscale.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Aydin Abiar <aydin@anyscale.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent
c1b06fc182
commit
76afe4edf8
@ -8,6 +8,11 @@ to avoid certain eager import breakage."""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
import sys
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@ -29,6 +34,22 @@ def main():
|
||||
|
||||
cli_env_setup()
|
||||
|
||||
# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "bench":
|
||||
logger.debug(
|
||||
"Bench command detected, must ensure current platform is not "
|
||||
"UnspecifiedPlatform to avoid device type inference error"
|
||||
)
|
||||
from vllm import platforms
|
||||
|
||||
if platforms.current_platform.is_unspecified():
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
|
||||
platforms.current_platform = CpuPlatform()
|
||||
logger.info(
|
||||
"Unspecified platform detected, switching to CPU Platform instead."
|
||||
)
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM CLI",
|
||||
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
|
||||
|
||||
@ -261,4 +261,14 @@ def __getattr__(name: str):
|
||||
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||
|
||||
|
||||
def __setattr__(name: str, value):
|
||||
if name == "current_platform":
|
||||
global _current_platform
|
||||
_current_platform = value
|
||||
elif name in globals():
|
||||
globals()[name] = value
|
||||
else:
|
||||
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||
|
||||
|
||||
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
|
||||
|
||||
@ -141,6 +141,9 @@ class Platform:
|
||||
def is_out_of_tree(self) -> bool:
|
||||
return self._enum == PlatformEnum.OOT
|
||||
|
||||
def is_unspecified(self) -> bool:
|
||||
return self._enum == PlatformEnum.UNSPECIFIED
|
||||
|
||||
def get_max_output_tokens(self, prompt_len: int) -> int:
|
||||
return sys.maxsize
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user