[Fix] Fix memory profiling when GPU is used by multiple processes (#2863)

Author: Woosuk Kwon
Date: 2024-02-13 19:52:34 -08:00 (committed by GitHub)
Parent: 0c48b37c31
Commit: 7e45107f51


@@ -84,6 +84,8 @@ class Worker:
             torch.cuda.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
+            torch.cuda.empty_cache()
+            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -126,7 +128,9 @@ class Worker:
         # profiled peak memory.
         torch.cuda.synchronize()
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        peak_memory = total_gpu_memory - free_gpu_memory
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, cache_dtype, self.model_config, self.parallel_config)
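
Why the change matters: the old formula total_gpu_memory - free_gpu_memory attributes every allocated byte on the GPU to this worker, including memory held by unrelated processes, which shrinks the KV cache budget accordingly. A minimal standalone sketch of the corrected flow, under stated assumptions: run_profiling_pass is a hypothetical stand-in for the worker's dummy forward pass, and gpu_memory_utilization mirrors vLLM's config knob of the same name.

    import torch

    def run_profiling_pass() -> None:
        # Hypothetical stand-in for the worker's dummy forward pass; a
        # throwaway allocation makes the sketch runnable on its own.
        _ = torch.empty(1 << 28, dtype=torch.uint8, device="cuda")

    def available_kv_cache_memory(gpu_memory_utilization: float = 0.9) -> int:
        torch.cuda.empty_cache()
        init_free, _ = torch.cuda.mem_get_info()  # baseline taken at init time

        run_profiling_pass()

        torch.cuda.synchronize()
        free, total = torch.cuda.mem_get_info()

        # Old (buggy): peak = total - free, which counts other processes'
        # allocations as this worker's usage.
        # New: only the memory consumed since the baseline, assuming the
        # other processes' usage did not change while profiling ran.
        peak = init_free - free
        return int(total * gpu_memory_utilization) - peak

The surrounding code then divides this budget by the cache block size (CacheEngine.get_cache_block_size above) to decide how many KV cache blocks fit on the GPU.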