diff --git a/vllm/distributed/ec_transfer/ec_connector/example_connector.py b/vllm/distributed/ec_transfer/ec_connector/example_connector.py index c9aad9e9fc8f3..3518044ce2e00 100644 --- a/vllm/distributed/ec_transfer/ec_connector/example_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/example_connector.py @@ -73,6 +73,7 @@ class ECExampleConnector(ECConnectorBase): data hashes (`mm_hash`) to encoder cache tensors. kwargs (dict): Additional keyword arguments for the connector. """ + from vllm.platforms import current_platform # Get the metadata metadata: ECConnectorMetadata = self._get_connector_metadata() @@ -91,7 +92,9 @@ class ECExampleConnector(ECConnectorBase): if mm_data.mm_hash in encoder_cache: continue filename = self._generate_filename_debug(mm_data.mm_hash) - ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda() + ec_cache = safetensors.torch.load_file( + filename, device=current_platform.device_type + )["ec_cache"] encoder_cache[mm_data.mm_hash] = ec_cache logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)