diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 676a440a79db8..e999a58320228 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -75,6 +75,13 @@ class CpuPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         return psutil.virtual_memory().total
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cpu.set_device(device)
+
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return False
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 50eedfa3c412f..b53d7e71a03eb 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -77,7 +77,7 @@ class CudaPlatformBase(Platform):
         """
         Set the device for the current platform.
         """
-        super().set_device(device)
+        torch.cuda.set_device(device)
         # With this trick we can force the device to be set eagerly
         # see https://github.com/pytorch/pytorch/issues/155668
         # for why and when it is needed
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 0b1e2f2327901..3faf481087e45 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -45,6 +45,13 @@ class HpuPlatform(Platform):
     def inference_mode(cls):
         return torch.no_grad()
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.hpu.set_device(device)
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index b0ef9905481b4..d3060685e9848 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -305,7 +305,7 @@ class Platform:
         """
         Set the device for the current platform.
         """
-        torch.cuda.set_device(device)
+        raise NotImplementedError
 
     @classmethod
     def pre_register_and_update(cls,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 31f4699cd1b0c..709d86d6ce863 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -241,6 +241,17 @@ class RocmPlatform(Platform):
             logger.info("Using ROCmFlashAttention backend.")
             return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cuda.set_device(device)
+        # With this trick we can force the device to be set eagerly
+        # see https://github.com/pytorch/pytorch/issues/155668
+        # for why and when it is needed
+        _ = torch.zeros(1, device=device)
+
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_capability(cls,
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 6810944c848d7..10a7f7c60ee2f 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -55,6 +55,13 @@ class TpuPlatform(Platform):
         logger.info("Using Pallas V1 backend.")
         return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.tpu.set_device(device)
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         chip_type, _ = device.get_local_chips()
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 9bc2e2c57e996..fb69ed36af09c 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -45,6 +45,13 @@ class XPUPlatform(Platform):
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.xpu.set_device(device)
+
     @classmethod
     def get_device_capability(
         cls,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d1df0fd959b5e..916052ca5ebff 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -130,7 +130,7 @@ class Worker(WorkerBase):
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
             self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
+            current_platform.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 6d1f5749d8b2c..dc52accfbd390 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -132,7 +132,7 @@ class XPUWorker(Worker):
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
-            torch.xpu.set_device(self.device)
+            current_platform.set_device(self.device)
             torch.xpu.empty_cache()
             self.init_gpu_memory = torch.xpu.get_device_properties(
                 self.local_rank).total_memory