mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 22:24:39 +08:00
[core] overhaul memory profiling and fix backward compatibility (#10511)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
efbce85f4d
commit
551603feff
25
tests/entrypoints/llm/test_gpu_utilization.py
Normal file
25
tests/entrypoints/llm/test_gpu_utilization.py
Normal file
@ -0,0 +1,25 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def test_gpu_memory_utilization():
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# makes sure gpu_memory_utilization is per-instance limit,
|
||||
# not a global limit
|
||||
llms = [
|
||||
LLM(model="facebook/opt-125m",
|
||||
gpu_memory_utilization=0.3,
|
||||
enforce_eager=True) for i in range(3)
|
||||
]
|
||||
for llm in llms:
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
@ -36,7 +36,7 @@ def run_lmfe(sample_regex):
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
gpu_memory_utilization=0.6)
|
||||
gpu_memory_utilization=0.3)
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
outputs = llm.generate(
|
||||
prompts=[
|
||||
|
||||
@ -5,11 +5,13 @@ from functools import partial
|
||||
from typing import AsyncIterator, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
|
||||
get_open_port, merge_async_iterators, supports_kw)
|
||||
get_open_port, memory_profiling, merge_async_iterators,
|
||||
supports_kw)
|
||||
|
||||
from .utils import error_on_warning
|
||||
from .utils import error_on_warning, fork_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
|
||||
requires_kw_only=requires_kw_only,
|
||||
allow_var_kwargs=allow_var_kwargs
|
||||
) == is_supported
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_memory_profiling():
|
||||
# Fake out some model loading + inference memory usage to test profiling
|
||||
# Memory used by other processes will show up as cuda usage outside of torch
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import (
|
||||
CudaRTLibrary)
|
||||
lib = CudaRTLibrary()
|
||||
# 512 MiB allocation outside of this instance
|
||||
handle1 = lib.cudaMalloc(512 * 1024 * 1024)
|
||||
|
||||
baseline_memory_in_bytes = \
|
||||
torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
|
||||
|
||||
# load weights
|
||||
|
||||
weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
|
||||
|
||||
weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
|
||||
|
||||
with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
|
||||
weights_memory_in_bytes=weights_memory_in_bytes) as result:
|
||||
# make a memory spike, 1 GiB
|
||||
spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
|
||||
del spike
|
||||
|
||||
# Add some extra non-torch memory 256 MiB (simulate NCCL)
|
||||
handle2 = lib.cudaMalloc(256 * 1024 * 1024)
|
||||
|
||||
# Check that the memory usage is within 5% of the expected values
|
||||
non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
|
||||
torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
|
||||
assert abs(non_torch_ratio - 1) <= 0.05
|
||||
assert abs(torch_peak_ratio - 1) <= 0.05
|
||||
del weights
|
||||
lib.cudaFree(handle1)
|
||||
lib.cudaFree(handle2)
|
||||
|
||||
@ -31,10 +31,6 @@ def test_gpu_memory_profiling():
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
# Load the model so we can profile it
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
|
||||
# Set 10GiB as the total gpu ram to be device-agnostic
|
||||
def mock_mem_info():
|
||||
current_usage = torch.cuda.memory_stats(
|
||||
@ -46,20 +42,24 @@ def test_gpu_memory_profiling():
|
||||
|
||||
from unittest.mock import patch
|
||||
with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
|
||||
# Load the model so we can profile it
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
gpu_blocks, _ = worker.determine_num_available_blocks()
|
||||
|
||||
# Peak vram usage by torch should be 0.7077 GiB
|
||||
# Peak vram usage by torch should be 0.47 GiB
|
||||
# Model weights take 0.25 GiB
|
||||
# No memory should be allocated outside of torch
|
||||
# 9.0 GiB should be the utilization target
|
||||
# 8.2923 GiB should be available for the KV cache
|
||||
# 8.28 GiB should be available for the KV cache
|
||||
block_size = CacheEngine.get_cache_block_size(
|
||||
engine_config.cache_config, engine_config.model_config,
|
||||
engine_config.parallel_config)
|
||||
|
||||
expected_blocks = (8.2923 * 1024**3) // block_size
|
||||
expected_blocks = (8.28 * 1024**3) // block_size
|
||||
|
||||
# Check within a small tolerance for portability
|
||||
# Hardware, kernel, or dependency changes could all affect memory
|
||||
# utilization.
|
||||
# A 10 block tolerance here should be about 6MB of wiggle room.
|
||||
assert abs(gpu_blocks - expected_blocks) < 10
|
||||
# A 100 block tolerance here should be about 60MB of wiggle room.
|
||||
assert abs(gpu_blocks - expected_blocks) < 100
|
||||
|
||||
@ -487,11 +487,12 @@ class EngineArgs:
|
||||
help='The fraction of GPU memory to be used for the model '
|
||||
'executor, which can range from 0 to 1. For example, a value of '
|
||||
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
|
||||
'will use the default value of 0.9. This is a global gpu memory '
|
||||
'utilization limit, for example if 50%% of the gpu memory is '
|
||||
'already used before vLLM starts and --gpu-memory-utilization is '
|
||||
'set to 0.9, then only 40%% of the gpu memory will be allocated '
|
||||
'to the model executor.')
|
||||
'will use the default value of 0.9. This is a per-instance '
|
||||
'limit, and only applies to the current vLLM instance.'
|
||||
'It does not matter if you have another vLLM instance running '
|
||||
'on the same GPU. For example, if you have two vLLM instances '
|
||||
'running on the same GPU, you can set the GPU memory utilization '
|
||||
'to 0.5 for each instance.')
|
||||
parser.add_argument(
|
||||
'--num-gpu-blocks-override',
|
||||
type=int,
|
||||
|
||||
125
vllm/utils.py
125
vllm/utils.py
@ -23,10 +23,12 @@ import weakref
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
||||
Dict, Generic, Hashable, List, Literal, Optional,
|
||||
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
|
||||
Dict, Generator, Generic, Hashable, List, Literal,
|
||||
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
|
||||
overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
|
||||
# Finally kill the parent
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemorySnapshot:
|
||||
"""Memory snapshot."""
|
||||
torch_peak_in_bytes: int = 0
|
||||
torch_memory_in_bytes: int = 0
|
||||
timestamp: float = 0.0
|
||||
|
||||
def measure(self):
|
||||
self.torch_peak_in_bytes = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.peak"]
|
||||
self.torch_memory_in_bytes = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
self.timestamp = time.time()
|
||||
|
||||
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
|
||||
"""support a - b"""
|
||||
return MemorySnapshot(
|
||||
torch_peak_in_bytes=self.torch_peak_in_bytes -
|
||||
other.torch_peak_in_bytes,
|
||||
torch_memory_in_bytes=self.torch_memory_in_bytes -
|
||||
other.torch_memory_in_bytes,
|
||||
timestamp=self.timestamp - other.timestamp)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryProfilingResult:
|
||||
"""Memory profiling result.
|
||||
""" # noqa
|
||||
baseline_memory_in_bytes: int = 0
|
||||
non_kv_cache_memory_in_bytes: int = 0
|
||||
torch_peak_increase_in_bytes: int = 0
|
||||
non_torch_increase_in_bytes: int = 0
|
||||
weights_memory_in_bytes: float = 0
|
||||
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
profile_time: float = 0.0
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def memory_profiling(
|
||||
baseline_memory_in_bytes: int, weights_memory_in_bytes: int
|
||||
) -> Generator[MemoryProfilingResult, None, None]:
|
||||
"""Memory profiling context manager.
|
||||
baseline_memory_in_bytes: memory used by all the components other than
|
||||
the current vLLM instance. It contains: memory used by other processes, memory
|
||||
used by another vLLM instance in the same process, etc. It is usually measured
|
||||
before the current vLLM instance initialize the device. And we assume it is
|
||||
constant during the profiling of the current vLLM instance.
|
||||
weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
|
||||
Note that, before loading the model weights, we also initialize the device
|
||||
and distributed environment, which may consume some memory. This part is not
|
||||
included in the weights_memory_in_bytes because PyTorch does not control it.
|
||||
|
||||
The memory in one GPU can be classified into 3 categories:
|
||||
1. memory used by anything other than the current vLLM instance.
|
||||
2. memory used by torch in the current vLLM instance.
|
||||
3. memory used in the current vLLM instance, but not by torch.
|
||||
|
||||
A quantitive example:
|
||||
|
||||
Before creating the current vLLM instance:
|
||||
category 1: 1 GiB
|
||||
category 2: 0 GiB
|
||||
category 3: 0 GiB
|
||||
|
||||
After creating the current vLLM instance and loading the model,
|
||||
(i.e. before profiling):
|
||||
category 1: 1 GiB
|
||||
category 2: 2 GiB (model weights take 2 GiB)
|
||||
category 3: 0.5 GiB (memory used by NCCL)
|
||||
|
||||
During profiling (peak):
|
||||
category 1: 1 GiB
|
||||
category 2: 4 GiB (peak activation tensors take 2 GiB)
|
||||
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
|
||||
|
||||
After profiling:
|
||||
category 1: 1 GiB
|
||||
category 2: 3 GiB (after garbage-collecting activation tensors)
|
||||
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
|
||||
|
||||
In this case, non-kv cache takes 5 GiB in total, including:
|
||||
a. 2 GiB used by the model weights (category 2)
|
||||
b. 2 GiB reserved for the peak activation tensors (category 2)
|
||||
c. 1 GiB used by non-torch components (category 3)
|
||||
|
||||
The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
|
||||
|
||||
The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
|
||||
|
||||
(c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
|
||||
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
|
||||
""" # noqa
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
result = MemoryProfilingResult()
|
||||
|
||||
result.baseline_memory_in_bytes = baseline_memory_in_bytes
|
||||
# the part of memory used for holding the model weights
|
||||
result.weights_memory_in_bytes = weights_memory_in_bytes
|
||||
|
||||
result.before_profile.measure()
|
||||
|
||||
yield result
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
result.after_profile.measure()
|
||||
|
||||
diff = result.after_profile - result.before_profile
|
||||
result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
|
||||
current_cuda_memory_bytes = torch.cuda.mem_get_info(
|
||||
)[1] - torch.cuda.mem_get_info()[0]
|
||||
result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa
|
||||
result.profile_time = diff.timestamp
|
||||
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
|
||||
|
||||
@ -645,7 +645,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
return model_input
|
||||
|
||||
def load_model(self) -> None:
|
||||
return self._base_model_runner.load_model()
|
||||
self._base_model_runner.load_model()
|
||||
self.model_memory_usage = self._base_model_runner.model_memory_usage
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
@ -22,6 +21,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import GiB_bytes, memory_profiling
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
|
||||
@ -192,33 +192,22 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
start_time = time.time()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
self.model_runner.profile_run()
|
||||
torch.cuda.synchronize()
|
||||
with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
|
||||
self.init_gpu_memory,
|
||||
weights_memory_in_bytes=self.model_runner.
|
||||
model_memory_usage) as result:
|
||||
self.model_runner.profile_run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
|
||||
# Get the peak memory allocation recorded by torch
|
||||
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
|
||||
|
||||
# Check for any memory left around that may have been allocated on the
|
||||
# gpu outside of `torch`. NCCL operations, for example, can use a few
|
||||
# GB during a forward pass
|
||||
torch.cuda.empty_cache()
|
||||
torch_allocated_bytes = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
total_allocated_bytes = torch.cuda.mem_get_info(
|
||||
)[1] - torch.cuda.mem_get_info()[0]
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
|
||||
available_kv_cache_memory = (
|
||||
total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory)
|
||||
memory_for_current_instance = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
available_kv_cache_memory = (memory_for_current_instance -
|
||||
result.non_kv_cache_memory_in_bytes)
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
@ -233,24 +222,23 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
"Memory profiling results: "
|
||||
"duration=%.2f seconds, "
|
||||
"total_gpu_memory=%.2fGiB, "
|
||||
"initial_memory_usage=%.2fGiB, "
|
||||
"peak_torch_memory=%.2fGiB, "
|
||||
"memory_usage_post_profile=%.2fGiB, "
|
||||
"non_torch_memory=%.2fGiB, "
|
||||
"kv_cache_size=%.2fGiB, "
|
||||
"gpu_memory_utilization=%.2f.", end_time - start_time,
|
||||
total_gpu_memory / (1024**3),
|
||||
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
|
||||
(peak_memory - non_torch_allocations) / (1024**3),
|
||||
total_allocated_bytes / (1024**3),
|
||||
non_torch_allocations / (1024**3),
|
||||
available_kv_cache_memory / (1024**3),
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
|
||||
"the current vLLM instance can use "
|
||||
"total_gpu_memory "
|
||||
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
|
||||
" x gpu_memory_utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
|
||||
logger.info(msg)
|
||||
|
||||
# Final cleanup
|
||||
if self.model_runner.lora_manager:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user