Fix peak memory profiling (#2031)
parent 3fefe271ec
commit 30bad5c492
vllm/utils.py
@@ -40,11 +40,6 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     return int(max_shared_mem)
 
 
-def get_gpu_memory(gpu: int = 0) -> int:
-    """Returns the total memory of the GPU in bytes."""
-    return torch.cuda.get_device_properties(gpu).total_memory
-
-
 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
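For context on why the helper above can be dropped entirely: torch.cuda.mem_get_info() already reports both free and total device memory, so a separate total-memory lookup is redundant. A minimal sketch for illustration (not part of this diff):

import torch

# The removed helper read only the total memory from the device properties.
def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory

# The driver query used after this commit returns free and total memory in
# one call, so the worker no longer needs the helper above.
free_bytes, total_bytes = torch.cuda.mem_get_info()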
vllm/worker/worker.py
@@ -13,7 +13,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
-from vllm.utils import get_gpu_memory
 
 
 class Worker:
@@ -81,7 +80,6 @@ class Worker:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
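The reset call removed above is only needed for the allocator-based measurement. A hedged before/after comparison of the two measurement styles (illustration, not verbatim project code):

import torch

# Before: allocator statistics. Requires an explicit reset, and only sees
# memory that went through the PyTorch caching allocator.
torch.cuda.reset_peak_memory_stats()
# ... dummy forward pass runs here ...
peak_memory_old = torch.cuda.max_memory_allocated()

# After: ask the CUDA driver directly. No reset is needed, and allocations
# made outside the PyTorch allocator are also counted.
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory_new = total_gpu_memory - free_gpu_memory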
@@ -90,8 +88,9 @@ class Worker:
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize()
-        peak_memory = torch.cuda.max_memory_allocated()
-        total_gpu_memory = get_gpu_memory()
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        peak_memory = total_gpu_memory - free_gpu_memory
+
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
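Putting the hunks together, the profiled peak is now derived from the driver's free/total readings and then turned into a KV-cache block count. A rough sketch of that flow; gpu_memory_utilization and cache_block_size are stand-ins for the worker's config values, and since the diff above is cut off after num_gpu_blocks = int(, the final formula is a plausible reconstruction rather than a quote:

import torch

def sketch_num_gpu_blocks(cache_block_size: int,
                          gpu_memory_utilization: float = 0.9) -> int:
    """Hedged sketch of the post-fix profiling step; names are illustrative."""
    torch.cuda.empty_cache()
    # The worker runs a forward pass with dummy inputs at this point so that
    # weights, activations, and any non-PyTorch allocations sit on the GPU.
    torch.cuda.synchronize()

    # Driver-level view of device memory after the profiling run.
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    peak_memory = total_gpu_memory - free_gpu_memory

    # Whatever fits under the utilization cap, minus the profiled peak,
    # becomes KV-cache blocks.
    return int(
        (total_gpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)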