diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 375ea79d0e81..42433c717cf2 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -91,6 +91,7 @@ _NIXL_SUPPORTED_DEVICE = {
     ),
     "tpu": ("cpu",),
     "xpu": ("cpu",),
+    "cpu": ("cpu",),
 }
 # support for oot platform by providing mapping in current_platform
 _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
@@ -348,7 +349,13 @@ class NixlConnectorScheduler:
             + vllm_config.parallel_config.data_parallel_rank
         )
         assert vllm_config.kv_transfer_config is not None
-        self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
+        if current_platform.device_type == "cpu":
+            self.use_host_buffer = False
+        else:
+            self.use_host_buffer = (
+                vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
+            )
+
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
         # Background thread for handling new handshake requests.
@@ -820,7 +827,11 @@ class NixlConnectorWorker:
         # cpu kv buffer for xfer
         # used when device memory can not be registered under nixl
         self.host_xfer_buffers: dict[str, torch.Tensor] = {}
-        self.use_host_buffer = self.kv_buffer_device == "cpu"
+        if self.device_type == "cpu":
+            self.use_host_buffer = False
+        else:
+            self.use_host_buffer = self.kv_buffer_device == "cpu"
+
         # support for oot platform which can't register nixl memory
         # type based on kv_buffer_device
         nixl_memory_type = current_platform.get_nixl_memory_type()
@@ -1021,6 +1032,9 @@ class NixlConnectorWorker:
         # Set a no-op if the host buffer is not cpu.
         if self.kv_buffer_device != "cpu":
             return
+        # Set a no-op if self.device_type is 'cpu'.
+        if self.device_type == "cpu":
+            return
         assert self.use_host_buffer
         self.copy_blocks = copy_operation
 
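
All four hunks gate the same decision: on a CPU-only platform the KV cache already lives in host memory, so the extra host staging buffer is skipped even when kv_buffer_device is "cpu". A minimal standalone sketch of that logic follows; the helper name resolve_use_host_buffer is hypothetical (the patch sets self.use_host_buffer inline in the scheduler and worker constructors), and the two string arguments stand in for current_platform.device_type and kv_transfer_config.kv_buffer_device.

def resolve_use_host_buffer(device_type: str, kv_buffer_device: str) -> bool:
    # CPU-only platform: KV blocks are already in host memory, no staging buffer.
    if device_type == "cpu":
        return False
    # Accelerator platform: stage through a host buffer only when the KV
    # transfer buffer is configured to live on the CPU.
    return kv_buffer_device == "cpu"


# Expected behavior under the patched logic:
assert resolve_use_host_buffer("cpu", "cpu") is False
assert resolve_use_host_buffer("cuda", "cpu") is True
assert resolve_use_host_buffer("cuda", "cuda") is False

The same check is mirrored in set_host_xfer_buffer_ops (last hunk), which returns early on CPU platforms so no copy_blocks operation is registered.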