[core] overhaul memory profiling and fix backward compatibility (#10511)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2024-12-16 13:32:25 -08:00 committed by GitHub
parent efbce85f4d
commit 551603feff
8 changed files with 236 additions and 60 deletions

View File

@@ -0,0 +1,25 @@
from vllm import LLM, SamplingParams
def test_gpu_memory_utilization():
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# makes sure gpu_memory_utilization is a per-instance limit,
# not a global limit
llms = [
LLM(model="facebook/opt-125m",
gpu_memory_utilization=0.3,
enforce_eager=True) for _ in range(3)
]
for llm in llms:
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@@ -36,7 +36,7 @@ def run_lmfe(sample_regex):
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.6)
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[

View File

@@ -5,11 +5,13 @@ from functools import partial
from typing import AsyncIterator, Tuple
import pytest
import torch
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, merge_async_iterators, supports_kw)
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
from .utils import error_on_warning
from .utils import error_on_warning, fork_new_process_for_each_test
@pytest.mark.asyncio
@@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported
@fork_new_process_for_each_test
def test_memory_profiling():
# Fake out some model loading + inference memory usage to test profiling
# Memory used by other processes will show up as cuda usage outside of torch
from vllm.distributed.device_communicators.cuda_wrapper import (
CudaRTLibrary)
lib = CudaRTLibrary()
# 512 MiB allocation outside of this instance
handle1 = lib.cudaMalloc(512 * 1024 * 1024)
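# baseline = total - free from mem_get_info(), i.e. memory already in use on the
# device, including the 512 MiB allocated above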
baseline_memory_in_bytes = \
torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
# load weights
weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
weights_memory_in_bytes=weights_memory_in_bytes) as result:
# make a memory spike, 1 GiB
spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
del spike
# Add some extra non-torch memory 256 MiB (simulate NCCL)
handle2 = lib.cudaMalloc(256 * 1024 * 1024)
# Check that the memory usage is within 5% of the expected values
non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
assert abs(non_torch_ratio - 1) <= 0.05
assert abs(torch_peak_ratio - 1) <= 0.05
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)

View File

@@ -31,10 +31,6 @@ def test_gpu_memory_profiling():
is_driver_worker=True,
)
# Load the model so we can profile it
worker.init_device()
worker.load_model()
# Set 10GiB as the total gpu ram to be device-agnostic
def mock_mem_info():
current_usage = torch.cuda.memory_stats(
@@ -46,20 +42,24 @@ def test_gpu_memory_profiling():
from unittest.mock import patch
with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
# Load the model so we can profile it
worker.init_device()
worker.load_model()
gpu_blocks, _ = worker.determine_num_available_blocks()
# Peak vram usage by torch should be 0.7077 GiB
# Peak vram usage by torch should be 0.47 GiB
# Model weights take 0.25 GiB
# No memory should be allocated outside of torch
# 9.0 GiB should be the utilization target
# 8.2923 GiB should be available for the KV cache
# 8.28 GiB should be available for the KV cache
block_size = CacheEngine.get_cache_block_size(
engine_config.cache_config, engine_config.model_config,
engine_config.parallel_config)
expected_blocks = (8.2923 * 1024**3) // block_size
expected_blocks = (8.28 * 1024**3) // block_size
# Check within a small tolerance for portability
# Hardware, kernel, or dependency changes could all affect memory
# utilization.
# A 10 block tolerance here should be about 6MB of wiggle room.
assert abs(gpu_blocks - expected_blocks) < 10
# A 100 block tolerance here should be about 60MB of wiggle room.
assert abs(gpu_blocks - expected_blocks) < 100
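For reference, the expected block count in this test follows directly from the figures quoted in the comments above; a minimal arithmetic sketch (the 10 GiB total and 0.9 utilization come from the mocked setup, block_size from CacheEngine):
GiB = 1024**3
utilization_target = 0.9 * 10 * GiB     # 9.0 GiB of the mocked 10 GiB device
torch_peak = 0.47 * GiB                 # peak torch allocations during profiling
weights = 0.25 * GiB                    # model weights
kv_cache_budget = utilization_target - torch_peak - weights   # 8.28 GiB
# expected_blocks = kv_cache_budget // block_size, which the assertion above
# compares against gpu_blocks with a 100-block tolerance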

View File

@@ -487,11 +487,12 @@ class EngineArgs:
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a global gpu memory '
'utilization limit, for example if 50%% of the gpu memory is '
'already used before vLLM starts and --gpu-memory-utilization is '
'set to 0.9, then only 40%% of the gpu memory will be allocated '
'to the model executor.')
'will use the default value of 0.9. This is a per-instance '
'limit, and only applies to the current vLLM instance. '
'It does not matter if you have another vLLM instance running '
'on the same GPU. For example, if you have two vLLM instances '
'running on the same GPU, you can set the GPU memory utilization '
'to 0.5 for each instance.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
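A minimal illustration of the per-instance semantics described in the new help text; the 0.5 split is the example value from the text, and the model name is only a placeholder:
from vllm import LLM

# two vLLM instances sharing one GPU, each budgeting against half of the device memory
llm_a = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5, enforce_eager=True)
llm_b = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5, enforce_eager=True)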

View File

@@ -23,10 +23,12 @@ import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generic, Hashable, List, Literal, Optional,
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
overload)
from uuid import uuid4
import numpy as np
@@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
# Finally kill the parent
with contextlib.suppress(ProcessLookupError):
os.kill(pid, signal.SIGKILL)
@dataclass
class MemorySnapshot:
"""Memory snapshot."""
torch_peak_in_bytes: int = 0
torch_memory_in_bytes: int = 0
timestamp: float = 0.0
def measure(self):
self.torch_peak_in_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.peak"]
self.torch_memory_in_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
"""support a - b"""
return MemorySnapshot(
torch_peak_in_bytes=self.torch_peak_in_bytes -
other.torch_peak_in_bytes,
torch_memory_in_bytes=self.torch_memory_in_bytes -
other.torch_memory_in_bytes,
timestamp=self.timestamp - other.timestamp)
@dataclass
class MemoryProfilingResult:
"""Memory profiling result.
""" # noqa
baseline_memory_in_bytes: int = 0
non_kv_cache_memory_in_bytes: int = 0
torch_peak_increase_in_bytes: int = 0
non_torch_increase_in_bytes: int = 0
weights_memory_in_bytes: float = 0
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
profile_time: float = 0.0
@contextlib.contextmanager
def memory_profiling(
baseline_memory_in_bytes: int, weights_memory_in_bytes: int
) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.
baseline_memory_in_bytes: memory used by all components other than
the current vLLM instance. It includes memory used by other processes, memory
used by another vLLM instance in the same process, etc. It is usually measured
before the current vLLM instance initializes the device, and we assume it stays
constant while the current vLLM instance is being profiled.
weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
and distributed environment, which may consume some memory. This part is not
included in weights_memory_in_bytes because PyTorch does not control it.
The memory in one GPU can be classified into 3 categories:
1. memory used by anything other than the current vLLM instance.
2. memory used by torch in the current vLLM instance.
3. memory used in the current vLLM instance, but not by torch.
A quantitative example:
Before creating the current vLLM instance:
category 1: 1 GiB
category 2: 0 GiB
category 3: 0 GiB
After creating the current vLLM instance and loading the model,
(i.e. before profiling):
category 1: 1 GiB
category 2: 2 GiB (model weights take 2 GiB)
category 3: 0.5 GiB (memory used by NCCL)
During profiling (peak):
category 1: 1 GiB
category 2: 4 GiB (peak activation tensors take 2 GiB)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
After profiling:
category 1: 1 GiB
category 2: 3 GiB (after garbage-collecting activation tensors)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
In this case, non-KV-cache memory takes 5 GiB in total, including:
a. 2 GiB used by the model weights (category 2)
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given by the argument `weights_memory_in_bytes`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
(c.) is trickier. We measure the total memory used on this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
then subtract the baseline memory, the memory used by the model weights, and the diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
""" # noqa
torch.cuda.reset_peak_memory_stats()
result = MemoryProfilingResult()
result.baseline_memory_in_bytes = baseline_memory_in_bytes
# the part of memory used for holding the model weights
result.weights_memory_in_bytes = weights_memory_in_bytes
result.before_profile.measure()
yield result
gc.collect()
torch.cuda.empty_cache()
result.after_profile.measure()
diff = result.after_profile - result.before_profile
result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
current_cuda_memory_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa
result.profile_time = diff.timestamp
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
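To make the bookkeeping above concrete, a minimal arithmetic sketch using the quantitative example from the docstring (all values in GiB; the variable names are illustrative, not part of the API):
baseline = 1.0                # memory used by everything outside this instance
weights = 2.0                 # (a.) passed in as weights_memory_in_bytes
torch_peak_increase = 2.0     # (b.) allocated_bytes.all.peak rose from 2 GiB to 4 GiB
total_used_after = 5.0        # mem_get_info(): total - free after profiling (1 + 3 + 1)
torch_current_diff = 1.0      # allocated_bytes.all.current rose from 2 GiB to 3 GiB
non_torch_increase = total_used_after - baseline - weights - torch_current_diff   # (c.) = 1 GiB
non_kv_cache = weights + torch_peak_increase + non_torch_increase   # 5 GiB, as in the docstring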

View File

@@ -645,7 +645,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return model_input
def load_model(self) -> None:
return self._base_model_runner.load_model()
self._base_model_runner.load_model()
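# keep model_memory_usage in sync so the worker can report it as weights_memory_in_bytes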
self.model_memory_usage = self._base_model_runner.model_memory_usage
def save_sharded_state(
self,

View File

@@ -1,7 +1,6 @@
"""A GPU worker class."""
import gc
import os
import time
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
@@ -22,6 +21,7 @@ from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, memory_profiling
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -192,33 +192,22 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
start_time = time.time()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
torch.cuda.synchronize()
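# baseline = total_gpu_memory - self.init_gpu_memory, i.e. memory that was already
# in use by everything else when this instance initialized the device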
with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
self.init_gpu_memory,
weights_memory_in_bytes=self.model_runner.
model_memory_usage) as result:
self.model_runner.profile_run()
torch.cuda.synchronize()
self._assert_memory_footprint_increased_during_profiling()
# Get the peak memory allocation recorded by torch
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
# Check for any memory left around that may have been allocated on the
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch.cuda.empty_cache()
torch_allocated_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
memory_for_current_instance = total_gpu_memory * \
self.cache_config.gpu_memory_utilization
available_kv_cache_memory = (memory_for_current_instance -
result.non_kv_cache_memory_in_bytes)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@@ -233,24 +222,23 @@ class Worker(LocalOrDistributedWorkerBase):
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
end_time = time.time()
logger.info(
"Memory profiling results: "
"duration=%.2f seconds, "
"total_gpu_memory=%.2fGiB, "
"initial_memory_usage=%.2fGiB, "
"peak_torch_memory=%.2fGiB, "
"memory_usage_post_profile=%.2fGiB, "
"non_torch_memory=%.2fGiB, "
"kv_cache_size=%.2fGiB, "
"gpu_memory_utilization=%.2f.", end_time - start_time,
total_gpu_memory / (1024**3),
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
(peak_memory - non_torch_allocations) / (1024**3),
total_allocated_bytes / (1024**3),
non_torch_allocations / (1024**3),
available_kv_cache_memory / (1024**3),
self.cache_config.gpu_memory_utilization)
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
"the current vLLM instance can use "
"total_gpu_memory "
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
" x gpu_memory_utilization "
f"({self.cache_config.gpu_memory_utilization:.2f})"
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
"model weights take "
f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
" non_torch_memory takes "
f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
" PyTorch activation peak memory takes "
f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
logger.info(msg)
# Final cleanup
if self.model_runner.lora_manager:
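For reference, a small sketch of how the new per-instance budget turns into a KV-cache block count; only the two budget formulas come from the change above, and every figure is an assumption chosen for illustration:
GiB_bytes = 1024**3                       # same constant vLLM imports from vllm.utils

total_gpu_memory = 80 * GiB_bytes         # assumed 80 GiB device
gpu_memory_utilization = 0.9              # default value of --gpu-memory-utilization
non_kv_cache_memory = 20 * GiB_bytes      # assumed: weights + activation peak + non-torch
block_size = 2 * 1024 * 1024              # assumed per-block KV-cache size in bytes

memory_for_current_instance = total_gpu_memory * gpu_memory_utilization
available_kv_cache_memory = memory_for_current_instance - non_kv_cache_memory
num_gpu_blocks = max(int(available_kv_cache_memory // block_size), 0)   # 26624 blocks here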