[misc] improve memory profiling (#11809)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent ef68eb28d8
commit 889e662eae
@@ -5,6 +5,7 @@ from typing import AsyncIterator, Tuple
 
 import pytest
 import torch
+from vllm_test_utils import monitor
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                         get_open_port, memory_profiling, merge_async_iterators,
@@ -289,8 +290,16 @@ def test_memory_profiling():
 
     weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
 
+    def measure_current_non_torch():
+        free, total = torch.cuda.mem_get_info()
+        current_used = total - free
+        current_torch = torch.cuda.memory_reserved()
+        current_non_torch = current_used - current_torch
+        return current_non_torch
+
     with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+        monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
         del spike
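
The combined `with` statement introduced above enters `memory_profiling` and `monitor` together, so the non-torch counter is sampled while the profiled workload runs. For readers unfamiliar with the backslash-continued form, a minimal sketch of the equivalent nested version (the helper names come from the test in this hunk; only the layout differs):

```python
from vllm.utils import memory_profiling
from vllm_test_utils import monitor

# baseline_memory_in_bytes, weights_memory_in_bytes and measure_current_non_torch
# are defined in the test shown above; this only restates the combined `with`.
with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
                      weights_memory_in_bytes=weights_memory_in_bytes) as result:
    with monitor(measure_current_non_torch) as monitored_values:
        ...  # profiled workload
```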
@@ -298,7 +307,15 @@ def test_memory_profiling():
         # Add some extra non-torch memory 256 MiB (simulate NCCL)
         handle2 = lib.cudaMalloc(256 * 1024 * 1024)
 
+    # this is an analytic value, it is exact,
+    # we only have 256 MiB non-torch memory increase
+    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
+    assert measured_diff == 256 * 1024 * 1024
+
     # Check that the memory usage is within 5% of the expected values
+    # 5% tolerance is caused by PyTorch caching allocator,
+    # we cannot control PyTorch's behavior of its internal buffers,
+    # which causes a small error (<10 MiB in practice)
     non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
     torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
     assert abs(non_torch_ratio - 1) <= 0.05
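
The new comments explain the two checks: the monitored non-torch counter must move by exactly the 256 MiB `cudaMalloc`, while the profiler's own numbers get a 5% tolerance because PyTorch's caching allocator keeps internal buffers we cannot control. A minimal sketch of that allocator slack, assuming a CUDA device is available (exact figures depend on allocator state):

```python
import torch

x = torch.empty(64 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # 64 MiB tensor
print(torch.cuda.memory_allocated())  # ~64 MiB attributed to live tensors
print(torch.cuda.memory_reserved())   # >= allocated: bytes actually obtained from CUDA
del x
print(torch.cuda.memory_allocated())  # drops back toward zero
print(torch.cuda.memory_reserved())   # often unchanged: freed blocks stay cached
```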
@@ -4,5 +4,6 @@ It does not import any vLLM modules.
 """
 
 from .blame import BlameResult, blame
+from .monitor import MonitoredValues, monitor
 
-__all__ = ["blame", "BlameResult"]
+__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
tests/vllm_test_utils/vllm_test_utils/monitor.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+import contextlib
+import dataclasses
+import sys
+import traceback
+from typing import Callable, Generator, Generic, TypeVar
+
+_T = TypeVar("_T")
+
+
+@dataclasses.dataclass
+class MonitoredValues(Generic[_T]):
+    values: list[_T] = dataclasses.field(default_factory=list)
+    trace_stacks: list[str] = dataclasses.field(default_factory=list)
+
+
+@contextlib.contextmanager
+def monitor(
+        measure_func: Callable[[],
+                               _T]) -> Generator[MonitoredValues[_T], None, None]:
+    """
+    Trace the function calls to continuously monitor the change of
+    a value.
+
+    Usage:
+
+    ```python
+
+    def measure_func():
+        ... # measure the current value
+        return current_value
+
+    with monitor(measure_func) as monitored_values:
+        # do something
+
+        monitored_values.values # all changes of the values
+        monitored_values.trace_stacks # trace stacks of every change
+    ```
+    """
+    monitored_values = MonitoredValues[_T]()
+
+    def _trace_calls(frame, event, arg=None):
+        nonlocal monitored_values
+        if event in ['line']:
+            # triggered by every line of Python code.
+            # only Python functions will trigger it,
+            # c/cpp functions will not trigger it.
+            try:
+                # Temporarily disable the trace function
+                sys.settrace(None)
+                # do a measurement
+                current_value = measure_func()
+                if len(monitored_values.values
+                       ) == 0 or current_value != monitored_values.values[-1]:
+                    monitored_values.values.append(current_value)
+                    monitored_values.trace_stacks.append("".join(
+                        traceback.format_stack()))
+                # Re-enable the trace function
+                sys.settrace(_trace_calls)
+            except NameError:
+                # modules are deleted during shutdown
+                pass
+        return _trace_calls
+
+    try:
+        sys.settrace(_trace_calls)
+        yield monitored_values
+    finally:
+        sys.settrace(None)
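
The new `monitor()` helper installs a `sys.settrace` hook that re-runs `measure_func` on every traced Python line and records each distinct value together with the stack that produced it. A usage sketch, assuming the package above is importable as `vllm_test_utils`; `bump` is a hypothetical helper whose Python frame gives the tracer `'line'` events to sample on (work done purely in C, with no Python frames, is only observed at the next traced Python line):

```python
from vllm_test_utils import monitor

state = {"x": 0}

def measure_x():
    return state["x"]

def bump(n: int) -> None:
    for _ in range(n):
        state["x"] += 1  # each loop line lets the tracer re-measure

with monitor(measure_x) as monitored:
    bump(3)

print(monitored.values)             # e.g. [0, 1, 2, 3]: every distinct observed value
print(len(monitored.trace_stacks))  # one formatted stack per recorded change
```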
@@ -1742,10 +1742,10 @@ class MemorySnapshot:
     timestamp: float = 0.0
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        self.torch_memory_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
+        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # torch.cuda.memory_reserved() is how many bytes
+        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
+        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
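
`MemorySnapshot.measure()` now reads reserved memory (bytes PyTorch has actually requested from CUDA via `cudaMalloc`) instead of allocated memory (bytes attributed to live tensors), which lines up with what `torch.cuda.mem_get_info()` sees from outside the allocator. A short sketch of the counters involved, assuming a CUDA device:

```python
import torch

_ = torch.ones(1, device="cuda")  # touch the device so the counters are populated
stats = torch.cuda.memory_stats()
allocated_peak = stats["allocated_bytes.all.peak"]  # old source of torch_peak_in_bytes
reserved_now = torch.cuda.memory_reserved()         # new source of torch_memory_in_bytes
reserved_peak = torch.cuda.max_memory_reserved()    # new source of torch_peak_in_bytes
free, total = torch.cuda.mem_get_info()             # device-wide view (all processes)
print(allocated_peak, reserved_now, reserved_peak, total - free)
```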
@@ -1822,10 +1822,10 @@ def memory_profiling(
 
     The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
 
-    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
 
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
     """ # noqa
     torch.cuda.reset_peak_memory_stats()
 
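
The updated docstring spells out how (c.), the non-torch memory, is obtained: device-wide usage minus the baseline, the weights, and the change in PyTorch-reserved memory. A worked sketch with hypothetical byte counts, just to show the subtraction order (not an exact reproduction of the utility):

```python
GiB, MiB = 1024**3, 1024**2

total_used_after = 4 * GiB         # torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
baseline_memory = 1 * GiB          # memory already in use before this instance started
weights_memory = 512 * MiB         # passed in as `weights_memory_in_bytes` (a.)
torch_reserved_increase = 2 * GiB  # diff of torch.cuda.memory_reserved() over profiling

non_torch_increase = (total_used_after - baseline_memory - weights_memory -
                      torch_reserved_increase)
print(non_torch_increase // MiB, "MiB")  # 512 MiB attributed to non-torch users (c.)
```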