use device param in load_model method (#13037)

This commit is contained in:
Zhe Zhang 2025-02-19 16:05:02 +08:00 committed by GitHub
parent 3b05cd4555
commit fdc5df6f54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1107,7 +1107,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
with DeviceMemoryProfiler(self.device) as m:
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory