mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 04:01:48 +08:00
70 lines
2.1 KiB
Python
70 lines
2.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from vllm import envs
|
|
from vllm.logger import init_logger
|
|
|
|
from .interface import Platform, PlatformEnum
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
else:
|
|
VllmConfig = None
|
|
|
|
# Module-level logger for this file (used by
# NeuronPlatform.is_pin_memory_available to emit its warning).
logger = init_logger(__name__)
|
|
|
|
|
|
class NeuronPlatform(Platform):
    """Platform hooks and device metadata for Neuron devices.

    The class attributes identify the device to the rest of vLLM, and the
    classmethods override :class:`Platform` hooks with Neuron-specific
    answers and in-place config adjustments.
    """

    _enum = PlatformEnum.NEURON
    device_name: str = "neuron"
    device_type: str = "neuron"
    # Presumably the Ray resource name used when scheduling Neuron cores;
    # confirm against the Ray integration.
    ray_device_key: str = "neuron_cores"
    supported_quantization: list[str] = ["neuron_quant"]
    # Environment variable used to control which Neuron cores are visible.
    device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"neuron"``. The result does not depend on ``device_id``."""
        return "neuron"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return False. Async output is never supported, eager or not."""
        return False

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` and adjust it in place for Neuron.

        The method makes these adjustments:

        - Selects ``vllm.worker.neuron_worker.NeuronWorker`` when
          ``parallel_config.worker_cls`` is ``"auto"``.
        - Forces the ``"uni"`` executor backend when ``world_size > 1``.
        - Sets ``cache_config.block_size`` to ``model_config.max_model_len``
          when both a cache config and a model config are present.

        Raises:
            AssertionError: If a LoRA config or a speculative-decoding config
                is set. These are unsupported on this backend.
        """
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = \
                "vllm.worker.neuron_worker.NeuronWorker"

        if parallel_config.world_size > 1:
            # NOTE(review): multi-core configs are still driven by the
            # single-process "uni" executor; confirm this is intended.
            parallel_config.distributed_executor_backend = "uni"

        # Explicit raises instead of ``assert`` so these guards are not
        # stripped under ``python -O``. The exception type and messages are
        # kept identical for existing callers and tests.
        if vllm_config.lora_config is not None:
            raise AssertionError("LoRA is not supported for Neuron backend.")
        if vllm_config.speculative_config:
            raise AssertionError(
                "Speculative decoding not yet supported for Neuron backend.")

        cache_config = vllm_config.cache_config
        if cache_config:
            model_config = vllm_config.model_config
            # Without a model config there is no max_model_len to copy, so
            # leave block_size alone instead of raising AttributeError.
            if model_config is not None:
                # neuron needs block_size = max_model_len
                cache_config.block_size = model_config.max_model_len

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Return False; pinned memory is not supported on Neuron.

        A warning is logged on every call.
        """
        logger.warning("Pin memory is not supported on Neuron.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class.

        When ``envs.VLLM_USE_V1`` is set, this is the Neuron-specific
        communicator. Otherwise the inherited :class:`Platform` choice is used.
        """
        if envs.VLLM_USE_V1:
            return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator"  # noqa
        # Defer through super() so the MRO is honoured and the parent hook
        # receives ``cls``, rather than being bound to ``Platform`` directly.
        return super().get_device_communicator_cls()

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return True.

        Presumably this selects all-gather over gather for collective output
        collection. Confirm against ``Platform.use_all_gather``.
        """
        return True
|