mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
[misc] improve memory profiling (#11809)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
ef68eb28d8
commit
889e662eae
@ -5,6 +5,7 @@ from typing import AsyncIterator, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm_test_utils import monitor
|
||||
|
||||
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
|
||||
get_open_port, memory_profiling, merge_async_iterators,
|
||||
@ -289,8 +290,16 @@ def test_memory_profiling():
|
||||
|
||||
weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
|
||||
|
||||
def measure_current_non_torch():
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
current_used = total - free
|
||||
current_torch = torch.cuda.memory_reserved()
|
||||
current_non_torch = current_used - current_torch
|
||||
return current_non_torch
|
||||
|
||||
with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
|
||||
weights_memory_in_bytes=weights_memory_in_bytes) as result:
|
||||
weights_memory_in_bytes=weights_memory_in_bytes) as result, \
|
||||
monitor(measure_current_non_torch) as monitored_values:
|
||||
# make a memory spike, 1 GiB
|
||||
spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
|
||||
del spike
|
||||
@ -298,7 +307,15 @@ def test_memory_profiling():
|
||||
# Add some extra non-torch memory 256 MiB (simulate NCCL)
|
||||
handle2 = lib.cudaMalloc(256 * 1024 * 1024)
|
||||
|
||||
# this is an analytic value, it is exact,
|
||||
# we only have 256 MiB non-torch memory increase
|
||||
measured_diff = monitored_values.values[-1] - monitored_values.values[0]
|
||||
assert measured_diff == 256 * 1024 * 1024
|
||||
|
||||
# Check that the memory usage is within 5% of the expected values
|
||||
# 5% tolerance is caused by PyTorch caching allocator,
|
||||
# we cannot control PyTorch's behavior of its internal buffers,
|
||||
# which causes a small error (<10 MiB in practice)
|
||||
non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
|
||||
torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
|
||||
assert abs(non_torch_ratio - 1) <= 0.05
|
||||
|
||||
@ -4,5 +4,6 @@ It does not import any vLLM modules.
|
||||
"""
|
||||
|
||||
from .blame import BlameResult, blame
|
||||
from .monitor import MonitoredValues, monitor
|
||||
|
||||
__all__ = ["blame", "BlameResult"]
|
||||
__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
|
||||
|
||||
68
tests/vllm_test_utils/vllm_test_utils/monitor.py
Normal file
68
tests/vllm_test_utils/vllm_test_utils/monitor.py
Normal file
@ -0,0 +1,68 @@
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Callable, Generator, Generic, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MonitoredValues(Generic[_T]):
|
||||
values: list[_T] = dataclasses.field(default_factory=list)
|
||||
trace_stacks: list[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def monitor(
|
||||
measure_func: Callable[[],
|
||||
_T]) -> Generator[MonitoredValues[_T], None, None]:
|
||||
"""
|
||||
Trace the function calls to continuously monitor the change of
|
||||
a value.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
|
||||
def measure_func():
|
||||
... # measure the current value
|
||||
return current_value
|
||||
|
||||
with monitor(measure_func) as monitored_values:
|
||||
# do something
|
||||
|
||||
monitored_values.values # all changes of the values
|
||||
monitored_values.trace_stacks # trace stacks of every change
|
||||
```
|
||||
"""
|
||||
monitored_values = MonitoredValues[_T]()
|
||||
|
||||
def _trace_calls(frame, event, arg=None):
|
||||
nonlocal monitored_values
|
||||
if event in ['line']:
|
||||
# triggered by every line of Python code.
|
||||
# only Python functions will trigger it,
|
||||
# c/cpp functions will not trigger it.
|
||||
try:
|
||||
# Temporarily disable the trace function
|
||||
sys.settrace(None)
|
||||
# do a measurement
|
||||
current_value = measure_func()
|
||||
if len(monitored_values.values
|
||||
) == 0 or current_value != monitored_values.values[-1]:
|
||||
monitored_values.values.append(current_value)
|
||||
monitored_values.trace_stacks.append("".join(
|
||||
traceback.format_stack()))
|
||||
# Re-enable the trace function
|
||||
sys.settrace(_trace_calls)
|
||||
except NameError:
|
||||
# modules are deleted during shutdown
|
||||
pass
|
||||
return _trace_calls
|
||||
|
||||
try:
|
||||
sys.settrace(_trace_calls)
|
||||
yield monitored_values
|
||||
finally:
|
||||
sys.settrace(None)
|
||||
@ -1742,10 +1742,10 @@ class MemorySnapshot:
|
||||
timestamp: float = 0.0
|
||||
|
||||
def measure(self):
|
||||
self.torch_peak_in_bytes = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.peak"]
|
||||
self.torch_memory_in_bytes = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
|
||||
# torch.cuda.memory_reserved() is how many bytes
|
||||
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
||||
self.torch_memory_in_bytes = torch.cuda.memory_reserved()
|
||||
self.timestamp = time.time()
|
||||
|
||||
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
|
||||
@ -1822,10 +1822,10 @@ def memory_profiling(
|
||||
|
||||
The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
|
||||
|
||||
The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
|
||||
|
||||
(c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
|
||||
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
|
||||
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
|
||||
""" # noqa
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user