diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c7814f17375b..78cc352b1630 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1107,7 +1107,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: + with DeviceMemoryProfiler(self.device) as m: self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory