mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 21:21:22 +08:00
[Bugfix] Fix vllm bench ... on CPU-only head nodes (#25283)
Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Aydin Abiar <aydin@anyscale.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent
c1b06fc182
commit
76afe4edf8
@@ -8,6 +8,11 @@ to avoid certain eager import breakage."""
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -29,6 +34,22 @@ def main():
|
|||||||
|
|
||||||
cli_env_setup()
|
cli_env_setup()
|
||||||
|
|
||||||
|
# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] == "bench":
|
||||||
|
logger.debug(
|
||||||
|
"Bench command detected, must ensure current platform is not "
|
||||||
|
"UnspecifiedPlatform to avoid device type inference error"
|
||||||
|
)
|
||||||
|
from vllm import platforms
|
||||||
|
|
||||||
|
if platforms.current_platform.is_unspecified():
|
||||||
|
from vllm.platforms.cpu import CpuPlatform
|
||||||
|
|
||||||
|
platforms.current_platform = CpuPlatform()
|
||||||
|
logger.info(
|
||||||
|
"Unspecified platform detected, switching to CPU Platform instead."
|
||||||
|
)
|
||||||
|
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="vLLM CLI",
|
description="vLLM CLI",
|
||||||
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
|
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
|
||||||
|
|||||||
@@ -261,4 +261,14 @@ def __getattr__(name: str):
|
|||||||
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level setter that mirrors the module's `__getattr__`: it routes an
# assignment to "current_platform" into the private `_current_platform`
# global, overwrites any other name that already exists in this module's
# globals, and rejects unknown names with AttributeError.
#
# NOTE(review): Python's module attribute protocol only looks up module-level
# `__getattr__` and `__dir__`; a plain function named `__setattr__` in a module
# is not invoked by `module.attr = value`. Unless something outside this view
# installs it (e.g. a `types.ModuleType` subclass assigned to the module's
# `__class__`), `platforms.current_platform = ...` stores the value directly
# in the module namespace and `_current_platform` is left untouched — confirm.
def __setattr__(name: str, value):
|
||||||
|
# The public name "current_platform" is backed by the private global.
if name == "current_platform":
|
||||||
|
global _current_platform
|
||||||
|
_current_platform = value
|
||||||
|
# Any other name may only be rebound if it already exists in this module.
elif name in globals():
|
||||||
|
globals()[name] = value
|
||||||
|
# Unknown names are rejected rather than silently created.
else:
|
||||||
|
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
|
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
|
||||||
|
|||||||
@@ -141,6 +141,9 @@ class Platform:
|
|||||||
def is_out_of_tree(self) -> bool:
|
def is_out_of_tree(self) -> bool:
|
||||||
return self._enum == PlatformEnum.OOT
|
return self._enum == PlatformEnum.OOT
|
||||||
|
|
||||||
|
# Report whether this platform's `_enum` is PlatformEnum.UNSPECIFIED.
# The `vllm bench` entry point in this commit calls it to decide whether to
# replace the current platform with CpuPlatform.
def is_unspecified(self) -> bool:
|
||||||
|
return self._enum == PlatformEnum.UNSPECIFIED
|
||||||
|
|
||||||
def get_max_output_tokens(self, prompt_len: int) -> int:
|
def get_max_output_tokens(self, prompt_len: int) -> int:
|
||||||
return sys.maxsize
|
return sys.maxsize
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user