[Chore] Factor out logic for requesting initial memory (#30868)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-12-17 23:32:17 +08:00 committed by GitHub
parent 196cdc3224
commit 2497228ad4
3 changed files with 56 additions and 21 deletions


@@ -66,27 +66,43 @@ class MemorySnapshot:
torch_memory: int = 0
non_torch_memory: int = 0
timestamp: float = 0.0
device: torch.types.Device = None
auto_measure: bool = True
def __post_init__(self) -> None:
if self.device is None:
from vllm.platforms import current_platform
device_fn = current_platform.current_device
assert device_fn is not None
self.device_ = torch.device(device_fn())
else:
self.device_ = torch.device(self.device)
if self.auto_measure:
self.measure()
def measure(self) -> None:
from vllm.platforms import current_platform
device = self.device_
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
self.torch_peak = torch.cuda.memory_stats(device).get(
"allocated_bytes.all.peak", 0
)
self.free_memory, self.total_memory = torch.cuda.mem_get_info()
self.free_memory, self.total_memory = torch.cuda.mem_get_info(device)
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
if (
current_platform.is_cuda()
and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
and current_platform.get_device_capability(device.index)
in shared_sysmem_device_mem_sms
):
# On UMA (Orin, Thor and Spark) platform,
# where both CPU and GPU rely on system memory,
@@ -106,12 +122,18 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved()
self.torch_memory = torch.cuda.memory_reserved(device)
self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
if self.device_ != other.device_:
raise ValueError(
"The two snapshots should be from the same device! "
f"Found: {self.device_} vs. {other.device_}"
)
return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak,
free_memory=self.free_memory - other.free_memory,
@@ -120,6 +142,7 @@ class MemorySnapshot:
torch_memory=self.torch_memory - other.torch_memory,
non_torch_memory=self.non_torch_memory - other.non_torch_memory,
timestamp=self.timestamp - other.timestamp,
device=self.device_,
auto_measure=False,
)
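
For readers outside the diff, here is a minimal usage sketch of the now device-aware snapshot. It is an illustration under assumptions, not code from this commit: the `vllm.utils.mem_utils` import path is taken from the third file below, while the helper name, the workload callable, and the GiB conversion are hypothetical.

```python
# Sketch only: measure non-torch GPU memory allocated while `work()` runs.
# MemorySnapshot(device=...) measures on construction (auto_measure=True),
# and the new __sub__ check refuses to diff snapshots from different devices.
import torch

from vllm.utils.mem_utils import MemorySnapshot  # import path from the diff below


def non_torch_overhead_gib(work, device: torch.types.Device = None) -> float:
    before = MemorySnapshot(device=device)
    work()  # hypothetical workload, e.g. model loading or graph capture
    after = MemorySnapshot(device=device)
    delta = after - before  # raises ValueError if the two devices differ
    return delta.non_torch_memory / (1 << 30)
```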


@@ -56,6 +56,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
from .utils import request_memory
logger = init_logger(__name__)
if TYPE_CHECKING:
@@ -237,22 +239,8 @@ class Worker(WorkerBase):
torch.cuda.empty_cache()
# take current memory snapshot
self.init_snapshot = MemorySnapshot()
self.requested_memory = (
self.init_snapshot.total_memory
* self.cache_config.gpu_memory_utilization
)
if self.init_snapshot.free_memory < self.requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = request_memory(init_snapshot, self.cache_config)
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
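
Below is a hedged sketch of what the refactored startup path amounts to: take a snapshot on the worker's device, then delegate the utilization check to `request_memory` (added in the next file). The function name is hypothetical and the import paths are inferred from the imports visible in this diff.

```python
# Sketch only, not the Worker code itself: after this change the memory
# check reduces to "snapshot the device, then validate via request_memory".
import torch

from vllm.config import CacheConfig
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.worker.utils import request_memory  # matches `from .utils import request_memory`


def check_startup_memory(cache_config: CacheConfig, device: torch.types.Device) -> float:
    torch.cuda.empty_cache()
    init_snapshot = MemorySnapshot(device=device)
    # Raises ValueError when free memory < total_memory * gpu_memory_utilization.
    return request_memory(init_snapshot, cache_config)
```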


@@ -8,13 +8,15 @@ from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
@@ -248,6 +250,28 @@ def gather_mm_placeholders(
return placeholders[is_embed]
def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) -> float:
"""
Calculate the amount of memory required by vLLM, then validate
that the current amount of free memory is sufficient for that.
"""
requested_memory = init_snapshot.total_memory * cache_config.gpu_memory_utilization
if init_snapshot.free_memory < requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device {init_snapshot.device_} "
f"({GiB(init_snapshot.free_memory)}/"
f"{GiB(init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({cache_config.gpu_memory_utilization}, "
f"{GiB(requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
return requested_memory
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
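
To make the new check concrete, here is a test-style sketch with illustrative numbers (80 GiB total, 70 GiB free, 0.9 utilization): 0.9 × 80 GiB = 72 GiB requested exceeds the 70 GiB free, so `request_memory` should raise. The test name is hypothetical, the `SimpleNamespace` stands in for a real `CacheConfig`, and `auto_measure=False` keeps the hand-written numbers by skipping the CUDA queries.

```python
# Illustrative only: a synthetic snapshot exercising the failure path of
# request_memory without touching a real GPU.
from types import SimpleNamespace

import pytest

from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.worker.utils import request_memory


def test_request_memory_rejects_low_free_memory():
    snapshot = MemorySnapshot(
        free_memory=70 * GiB_bytes,
        total_memory=80 * GiB_bytes,
        device="cuda:0",      # torch.device("cuda:0") is constructed without CUDA init
        auto_measure=False,   # skip measure(); keep the synthetic numbers
    )
    cache_config = SimpleNamespace(gpu_memory_utilization=0.9)  # stand-in for CacheConfig
    with pytest.raises(ValueError):
        request_memory(snapshot, cache_config)
```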