[misc]refactor Platform.set_device method (#20262)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Kunshang Ji 2025-07-09 09:39:47 +08:00 committed by GitHub
parent 5eaf570050
commit 0b407479ef
9 changed files with 43 additions and 4 deletions
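In short, the change moves device selection behind the Platform interface: the base Platform.set_device now raises NotImplementedError, each backend (CPU, CUDA, HPU, ROCm, TPU, XPU) provides its own implementation, and worker code calls current_platform.set_device instead of hard-coding torch.cuda or torch.xpu. A minimal sketch of the resulting pattern, using class and method names from the diff below (the classes here are simplified stand-ins, not the full vLLM implementations):

import torch

class Platform:
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Set the device for the current platform."""
        raise NotImplementedError

class CpuPlatform(Platform):
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.cpu.set_device(device)

class CudaPlatformBase(Platform):
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.cuda.set_device(device)
        # Force the device to be set eagerly (see the CUDA/ROCm hunks below).
        _ = torch.zeros(1, device=device)

# Callers stay backend-agnostic, e.g.:
# CpuPlatform.set_device(torch.device("cpu"))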


@@ -75,6 +75,13 @@ class CpuPlatform(Platform):
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        return psutil.virtual_memory().total
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cpu.set_device(device)
    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return False


@@ -77,7 +77,7 @@ class CudaPlatformBase(Platform):
        """
        Set the device for the current platform.
        """
-        super().set_device(device)
+        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
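The comment in this hunk (the full method appears again in the ROCm hunk further down) is why set_device does more than call torch.cuda.set_device: per the diff's own note and the linked PyTorch issue, a throwaway allocation forces the device to be set eagerly rather than on first use. A standalone sketch of the idiom, with a helper name that is illustrative rather than taken from the diff:

import torch

def set_cuda_device_eagerly(device: torch.device) -> None:
    # Select the device for the current process/thread.
    torch.cuda.set_device(device)
    # Allocate a tiny tensor so the device is set eagerly, not lazily
    # (see https://github.com/pytorch/pytorch/issues/155668).
    _ = torch.zeros(1, device=device)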


@@ -45,6 +45,13 @@ class HpuPlatform(Platform):
    def inference_mode(cls):
        return torch.no_grad()
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.hpu.set_device(device)
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:


@@ -305,7 +305,7 @@ class Platform:
        """
        Set the device for the current platform.
        """
-        torch.cuda.set_device(device)
+        raise NotImplementedError
    @classmethod
    def pre_register_and_update(cls,


@@ -241,6 +241,17 @@ class RocmPlatform(Platform):
            logger.info("Using ROCmFlashAttention backend.")
            return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cuda.set_device(device)
+        # With this trick we can force the device to be set eagerly
+        # see https://github.com/pytorch/pytorch/issues/155668
+        # for why and when it is needed
+        _ = torch.zeros(1, device=device)
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,


@@ -55,6 +55,13 @@ class TpuPlatform(Platform):
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.tpu.set_device(device)
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        chip_type, _ = device.get_local_chips()


@@ -45,6 +45,13 @@ class XPUPlatform(Platform):
        logger.info("Using Flash Attention backend on V1 engine.")
        return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.xpu.set_device(device)
    @classmethod
    def get_device_capability(
        cls,


@@ -130,7 +130,7 @@ class Worker(WorkerBase):
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
+            current_platform.set_device(self.device)
            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
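With this hunk and the XPU worker hunk below, device setup in the workers no longer names a backend-specific API; both go through current_platform, which resolves to the active Platform subclass. A hedged sketch of the call site, assuming the usual vllm.platforms import and a hypothetical helper name:

import torch
from vllm.platforms import current_platform  # assumed import path

def init_worker_device(local_rank: int) -> torch.device:
    # The device string still comes from the worker (here the CUDA worker),
    # but the set_device call itself is dispatched to the active platform.
    device = torch.device(f"cuda:{local_rank}")
    current_platform.set_device(device)
    return device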


@@ -132,7 +132,7 @@ class XPUWorker(Worker):
        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
        ):
            self.device = torch.device(f"xpu:{self.local_rank}")
-            torch.xpu.set_device(self.device)
+            current_platform.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory