[Fix] Fix memory profiling when GPU is used by multiple processes (#2863)
commit 7e45107f51
parent 0c48b37c31
@@ -84,6 +84,8 @@ class Worker:
             torch.cuda.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
+            torch.cuda.empty_cache()
+            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
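For reference, torch.cuda.mem_get_info() returns a (free, total) tuple in bytes for the current CUDA device, so indexing [0] records the free memory available at worker initialization. A minimal sketch of the call (illustration only, not part of the commit):

import torch

# mem_get_info() reports driver-level (free, total) memory in bytes for the
# current CUDA device; the commit keeps only the free value, stored as
# self.init_gpu_memory, for use in the later profiling step.
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"free: {free_bytes / 2**30:.2f} GiB of {total_bytes / 2**30:.2f} GiB")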
@@ -126,7 +128,9 @@ class Worker:
         # profiled peak memory.
         torch.cuda.synchronize()
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        peak_memory = total_gpu_memory - free_gpu_memory
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, cache_dtype, self.model_config, self.parallel_config)
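The fix changes the peak-memory calculation from total_gpu_memory - free_gpu_memory, which also counts allocations made by other processes sharing the GPU, to self.init_gpu_memory - free_gpu_memory, which attributes only the drop in free memory since initialization to this worker. Below is a standalone sketch of that measurement pattern (not vLLM's actual Worker class; run_profiling_pass is a hypothetical stand-in for the model's profiling forward pass):

import torch

def profile_peak_memory(run_profiling_pass):
    # Record free GPU memory before profiling; memory already held by other
    # processes on this GPU is thereby excluded from the peak computed below.
    torch.cuda.empty_cache()
    init_gpu_memory = torch.cuda.mem_get_info()[0]

    run_profiling_pass()
    torch.cuda.synchronize()

    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    # Only the drop in free memory since initialization is attributed to this
    # process, assuming other processes keep their usage constant during
    # profiling (the same assumption noted in the commit).
    peak_memory = init_gpu_memory - free_gpu_memory
    return peak_memory, total_gpu_memory

# Hypothetical usage: a dummy CUDA allocation stands in for the profiling pass.
peak, total = profile_peak_memory(lambda: torch.empty(1 << 22, device="cuda"))
print(f"this process used ~{peak / 2**20:.0f} MiB of {total / 2**30:.1f} GiB")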