vllm/vllm/platforms/neuron.py
wangxiyuan e88db68cf5
[Platform] platform agnostic for EngineArgs initialization (#11225)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2024-12-16 22:11:06 -08:00

46 lines
1.3 KiB
Python

from typing import TYPE_CHECKING, Optional
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
# ``VllmConfig`` is imported only for static type checking. At runtime the
# name is bound to None instead (presumably to avoid an import cycle with
# vllm.config — TODO confirm), so annotations that reference it, such as the
# ``check_and_update_config`` parameter, still evaluate without error.
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None
# Module-level logger for this platform.
logger = init_logger(__name__)
class NeuronPlatform(Platform):
    """Platform hooks for AWS Neuron devices.

    Selects the Neuron worker implementation, sets the KV-cache block size
    to the model's maximum length, and reports that async output processing
    and pinned memory are unavailable.
    """

    _enum = PlatformEnum.NEURON
    device_name: str = "neuron"
    device_type: str = "neuron"
    # Quantization methods accepted on this platform.
    supported_quantization: list[str] = ["neuron_quant"]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return "neuron"; ``device_id`` is ignored."""
        return "neuron"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Report that async output processing is not supported.

        ``enforce_eager`` is accepted for interface compatibility and is
        ignored; the result is always False.
        """
        return False

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for Neuron.

        - If ``parallel_config.worker_cls`` is ``"auto"``, it is replaced
          with the Neuron worker class path.
        - If a cache config is present, its ``block_size`` is set to the
          model config's ``max_model_len``. The update is skipped when no
          model config is attached, because there is then no
          ``max_model_len`` to copy (previously this raised AttributeError).
        """
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = \
                "vllm.worker.neuron_worker.NeuronWorker"

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        if cache_config and model_config is not None:
            # neuron needs block_size = max_model_len
            cache_config.block_size = model_config.max_model_len

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Report that pinned host memory is unavailable (always False).

        A warning is logged on every call.
        """
        logger.warning("Pin memory is not supported on Neuron.")
        return False