[Neuron] [Bugfix] Fix neuron startup (#9374)

Co-authored-by: Jerzy Zagorski <jzagorsk@amazon.com>
This commit is contained in:
xendo 2024-10-22 14:51:41 +02:00 committed by GitHub
parent a48e3ec052
commit 9dbcce84a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 37 additions and 18 deletions

View File

@ -26,7 +26,8 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
if TYPE_CHECKING:
# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING or current_platform.is_neuron():
def register_fake(fn):
return lambda name: fn

View File

@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once)
is_hip, is_openvino, is_xpu, print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -215,8 +214,10 @@ class ModelConfig:
self.is_attention_free = self._init_attention_free()
self.has_inner_state = self._init_has_inner_state()
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
self.override_neuron_config = None
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
@ -368,7 +369,7 @@ class ModelConfig:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
if is_neuron(
if current_platform.is_neuron(
) and self.quantization not in neuron_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
@ -1112,7 +1113,7 @@ class DeviceConfig:
# Automated device type detection
if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif is_neuron():
elif current_platform.is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"

View File

@ -58,6 +58,13 @@ try:
except Exception:
pass
is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
@ -75,6 +82,9 @@ elif is_xpu:
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
else:
current_platform = UnspecifiedPlatform()

View File

@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
TPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
UNSPECIFIED = enum.auto()
@ -48,6 +49,9 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

9
vllm/platforms/neuron.py Normal file
View File

@ -0,0 +1,9 @@
from .interface import Platform, PlatformEnum
class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "neuron"

View File

@ -1,10 +1,13 @@
from importlib.util import find_spec
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
HAS_TRITON = find_spec("triton") is not None
# neuron has too old torch
HAS_TRITON = find_spec(
"triton") is not None and not current_platform.is_neuron()
if not HAS_TRITON:
logger.info("Triton not installed; certain GPU-related functions"

View File

@ -327,15 +327,6 @@ def is_openvino() -> bool:
return False
@lru_cache(maxsize=None)
def is_neuron() -> bool:
try:
import transformers_neuronx
except ImportError:
transformers_neuronx = None
return transformers_neuronx is not None
@lru_cache(maxsize=None)
def is_xpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
@ -786,7 +777,7 @@ def is_pin_memory_available() -> bool:
elif is_xpu():
print_warning_once("Pin memory is not supported on XPU.")
return False
elif is_neuron():
elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif current_platform.is_cpu() or is_openvino():