[misc]refactor Platform.set_device method (#20262)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

parent 5eaf570050
commit 0b407479ef
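
In short: before this change, the base Platform.set_device hard-coded torch.cuda.set_device, which is wrong on non-CUDA backends. The commit turns the base method into an abstract hook (raise NotImplementedError) and gives each platform (CPU, CUDA, HPU, ROCm, TPU, XPU) its own implementation; workers then delegate to current_platform.set_device instead of calling a backend-specific API. A minimal sketch of the resulting pattern, with simplified names rather than the full vLLM Platform API:

import torch


class Platform:
    """Simplified stand-in for vllm.platforms.interface.Platform."""

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Set the device for the current platform."""
        # The base class no longer assumes CUDA; subclasses must override.
        raise NotImplementedError


class CudaLikePlatform(Platform):
    """Sketch of the CUDA/ROCm-style override added by this commit."""

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.cuda.set_device(device)
        # Allocating a tiny tensor forces the device to be set eagerly;
        # see https://github.com/pytorch/pytorch/issues/155668
        _ = torch.zeros(1, device=device)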
@@ -75,6 +75,13 @@ class CpuPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         return psutil.virtual_memory().total
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cpu.set_device(device)
+
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return False
@@ -77,7 +77,7 @@ class CudaPlatformBase(Platform):
         """
         Set the device for the current platform.
         """
-        super().set_device(device)
+        torch.cuda.set_device(device)
         # With this trick we can force the device to be set eagerly
         # see https://github.com/pytorch/pytorch/issues/155668
         # for why and when it is needed
@@ -45,6 +45,13 @@ class HpuPlatform(Platform):
     def inference_mode(cls):
         return torch.no_grad()
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.hpu.set_device(device)
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
@@ -305,7 +305,7 @@ class Platform:
         """
         Set the device for the current platform.
         """
-        torch.cuda.set_device(device)
+        raise NotImplementedError
 
     @classmethod
     def pre_register_and_update(cls,
@@ -241,6 +241,17 @@ class RocmPlatform(Platform):
         logger.info("Using ROCmFlashAttention backend.")
         return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cuda.set_device(device)
+        # With this trick we can force the device to be set eagerly
+        # see https://github.com/pytorch/pytorch/issues/155668
+        # for why and when it is needed
+        _ = torch.zeros(1, device=device)
+
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_capability(cls,
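
A note on the trailing _ = torch.zeros(1, device=device) in the ROCm hunk above (the CUDA path carries the same comment): per the linked PyTorch issue, torch.cuda.set_device alone may not initialize the device context eagerly, so allocating a one-element tensor forces initialization up front. An illustrative check, not part of the diff, assuming a machine with a CUDA or ROCm device:

import torch

if torch.cuda.is_available():
    torch.cuda.set_device(0)
    # Depending on the PyTorch version, initialization may still be lazy here.
    print(torch.cuda.is_initialized())
    _ = torch.zeros(1, device="cuda:0")  # forces eager initialization
    print(torch.cuda.is_initialized())   # True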
@@ -55,6 +55,13 @@ class TpuPlatform(Platform):
         logger.info("Using Pallas V1 backend.")
         return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.tpu.set_device(device)
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         chip_type, _ = device.get_local_chips()
@@ -45,6 +45,13 @@ class XPUPlatform(Platform):
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.xpu.set_device(device)
+
     @classmethod
     def get_device_capability(
         cls,
@@ -130,7 +130,7 @@ class Worker(WorkerBase):
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
             self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
+            current_platform.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
@@ -132,7 +132,7 @@ class XPUWorker(Worker):
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
-            torch.xpu.set_device(self.device)
+            current_platform.set_device(self.device)
             torch.xpu.empty_cache()
             self.init_gpu_memory = torch.xpu.get_device_properties(
                 self.local_rank).total_memory
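
Both worker hunks follow the same pattern: the device string stays backend-specific (cuda:{rank}, xpu:{rank}), but the set-device call is delegated to current_platform, so the workers no longer reach into torch.cuda or torch.xpu directly for this step. A self-contained sketch of that delegation; the _CpuPlatform stand-in is hypothetical, while in vLLM vllm.platforms.current_platform resolves to the active Platform subclass:

import torch


class _CpuPlatform:
    """Hypothetical stand-in for the active vLLM Platform subclass."""

    device_type = "cpu"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.cpu.set_device(device)


current_platform = _CpuPlatform


def init_worker_device(local_rank: int) -> torch.device:
    # Mirrors Worker.init_device after this commit (generalized: the real
    # workers still hard-code their device strings).
    device = torch.device(f"{current_platform.device_type}:{local_rank}")
    current_platform.set_device(device)
    return device


print(init_worker_device(0))  # prints "cpu:0"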