vllm/vllm/utils/mem_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import gc
import time
from collections.abc import Generator
from dataclasses import dataclass, field
from functools import cache
import psutil
import torch
import torch.types
from .mem_constants import GiB_bytes


@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
from vllm import _custom_ops as ops
max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A value of 0 would cause MAX_SEQ_LEN to become negative
    # and test_attention.py to fail.
    assert max_shared_mem > 0, "max_shared_mem cannot be zero"
    return int(max_shared_mem)


def get_cpu_memory() -> int:
"""Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total


class DeviceMemoryProfiler:
def __init__(self, device: torch.types.Device | None = None):
        self.device = device

    def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
from vllm.platforms import current_platform
gc.collect()
        return current_platform.get_current_memory_usage(self.device)

    def __enter__(self):
self.initial_memory = self.current_memory_usage()
# This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
self.final_memory = self.current_memory_usage()
self.consumed_memory = self.final_memory - self.initial_memory
# Force garbage collection
gc.collect()
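

# A minimal usage sketch (not part of the original module): DeviceMemoryProfiler
# is a context manager, so a caller can wrap a block of work and read
# `consumed_memory` afterwards. The workload below is a hypothetical placeholder.
def _example_device_memory_profiler() -> float:
    with DeviceMemoryProfiler() as profiler:
        workload = torch.empty(1024, 1024, device="cuda")  # hypothetical ~4 MiB fp32 allocation
    consumed = profiler.consumed_memory  # final_memory - initial_memory, in bytes
    del workload
    return consumed
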
@dataclass
class MemorySnapshot:
"""Memory snapshot."""
torch_peak: int = 0
free_memory: int = 0
total_memory: int = 0
cuda_memory: int = 0
torch_memory: int = 0
non_torch_memory: int = 0
timestamp: float = 0.0
    auto_measure: bool = True

    def __post_init__(self):
if self.auto_measure:
            self.measure()

    def measure(self):
        from vllm.platforms import current_platform

        # We measure the torch peak memory usage via allocated_bytes,
        # rather than `torch.cuda.memory_reserved()`.
        # After `torch.cuda.reset_peak_memory_stats()`,
        # `torch.cuda.memory_reserved()` keeps growing and only shrinks
        # when we call `torch.cuda.empty_cache()` or an OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
self.free_memory, self.total_memory = torch.cuda.mem_get_info()
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
if (
current_platform.is_cuda()
and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
):
            # On UMA platforms (Orin, Thor and Spark),
            # where both CPU and GPU rely on system memory,
            # cudaMemGetInfo reports the amount of free system memory
            # rather than what is actually available.
            # In that case, torch.cuda.mem_get_info() only reports "free"
            # memory, which can be lower than what is actually available
            # because it does not include cache memory.
            # The reference below explains how to compute the proper value
            # yourself:
            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
self.free_memory = psutil.virtual_memory().available
self.cuda_memory = self.total_memory - self.free_memory
        # torch.cuda.memory_reserved() is how many bytes
        # PyTorch gets from cuda (by calling cudaMalloc, etc.);
        # it is used to measure the non-torch memory usage.
self.torch_memory = torch.cuda.memory_reserved()
self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak,
free_memory=self.free_memory - other.free_memory,
total_memory=self.total_memory - other.total_memory,
cuda_memory=self.cuda_memory - other.cuda_memory,
torch_memory=self.torch_memory - other.torch_memory,
non_torch_memory=self.non_torch_memory - other.non_torch_memory,
timestamp=self.timestamp - other.timestamp,
auto_measure=False,
)
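

# A minimal sketch (not part of the original file): MemorySnapshot measures on
# construction (auto_measure=True), and two snapshots can be subtracted to obtain
# field-wise deltas; the result of the subtraction is not re-measured.
def _example_snapshot_delta() -> int:
    before = MemorySnapshot()  # measures immediately
    workload = torch.empty(256, 1024, 1024, device="cuda")  # hypothetical ~1 GiB fp32 allocation
    after = MemorySnapshot()  # measures again
    delta = after - before  # per-field difference, auto_measure=False
    del workload
    return delta.cuda_memory  # newly used device memory, in bytes
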
@dataclass
class MemoryProfilingResult:
"""Memory profiling result. All numbers are in bytes."""
non_kv_cache_memory: int = 0
torch_peak_increase: int = 0
non_torch_increase: int = 0
weights_memory: float = 0
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0

    def __repr__(self) -> str:
return (
f"Memory profiling takes {self.profile_time:.2f} seconds. "
f"Total non KV cache memory: "
f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
f"torch peak memory increase: "
f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
f"non-torch forward increase memory: "
f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB."
    )


@contextlib.contextmanager
def memory_profiling(
baseline_snapshot: MemorySnapshot, weights_memory: int
) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.
baseline_snapshot: the memory snapshot before the current vLLM instance.
weights_memory: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
and distributed environment, which may consume some memory. This part is not
included in the weights_memory because PyTorch does not control it.
The memory in one GPU can be classified into 3 categories:
1. memory used by anything other than the current vLLM instance.
2. memory used by torch in the current vLLM instance.
3. memory used in the current vLLM instance, but not by torch.
A quantitive example:
Before creating the current vLLM instance:
category 1: 1 GiB
category 2: 0 GiB
category 3: 0 GiB
After creating the current vLLM instance and loading the model,
(i.e. before profiling):
category 1: 1 GiB
category 2: 2 GiB (model weights take 2 GiB)
category 3: 0.5 GiB (memory used by NCCL)
During profiling (peak):
category 1: 1 GiB
category 2: 4 GiB (peak activation tensors take 2 GiB)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
After profiling:
category 1: 1 GiB
category 2: 3 GiB (after garbage-collecting activation tensors)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
In this case, non-kv cache takes 5 GiB in total, including:
a. 2 GiB used by the model weights (category 2)
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
result = MemoryProfilingResult()
result.before_create = baseline_snapshot
# the part of memory used for holding the model weights
result.weights_memory = weights_memory
result.before_profile.measure()
yield result
gc.collect()
torch.cuda.empty_cache()
result.after_profile.measure()
diff_profile = result.after_profile - result.before_profile
diff_from_create = result.after_profile - result.before_create
result.torch_peak_increase = diff_profile.torch_peak
result.non_torch_increase = diff_from_create.non_torch_memory
result.profile_time = diff_profile.timestamp
non_torch_memory = result.non_torch_increase
peak_activation_memory = result.torch_peak_increase
result.non_kv_cache_memory = (
non_torch_memory + peak_activation_memory + result.weights_memory
) # noqa
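

# A minimal usage sketch (not part of the original file): take a baseline snapshot
# before this vLLM instance allocates anything, load the weights, then run a dummy
# forward pass inside the context manager. The tensors below are hypothetical
# stand-ins for the real model weights and activations.
def _example_memory_profiling() -> None:
    baseline = MemorySnapshot()  # only "category 1" memory is in use at this point

    # Hypothetical "model load": a 2 GiB fp32 tensor standing in for the weights.
    weights = torch.empty(512, 1024, 1024, device="cuda")
    weights_bytes = weights.numel() * weights.element_size()

    with memory_profiling(baseline, weights_memory=weights_bytes) as result:
        # Hypothetical profiling run: a transient 1 GiB activation-sized allocation.
        activations = torch.empty(256, 1024, 1024, device="cuda")
        del activations

    # non_kv_cache_memory = weights_memory + torch_peak_increase + non_torch_increase
    # (all in bytes), analogous to the 2 GiB + 2 GiB + 1 GiB = 5 GiB docstring example.
    print(result)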