Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-15 01:15:31 +08:00
[Chore] Separate out vllm.utils.mem_utils (#27143)
Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com>
Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 83004020fd
commit 1d165d6d85
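This change moves the memory-size constants (MB_bytes, MiB_bytes, GB_bytes, GiB_bytes) out of vllm.utils into a new vllm.utils.mem_constants module, and moves the memory helpers (get_max_shared_memory_bytes, get_cpu_memory, DeviceMemoryProfiler, MemorySnapshot, MemoryProfilingResult, memory_profiling) into a new vllm.utils.mem_utils module; the remaining hunks update call sites to the new import paths. A minimal sketch of the import migration, using only names that appear in the diff below:

# before this commit
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling

# after this commit
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling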
@@ -6,7 +6,7 @@ import torch

 from vllm import LLM, SamplingParams
 from vllm.device_allocator.cumem import CuMemAllocator
-from vllm.utils import GiB_bytes
+from vllm.utils.mem_constants import GiB_bytes

 from ..utils import create_new_process_for_each_test

@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.attention.layer import Attention, MultiHeadAttention
 from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes
+from vllm.utils.mem_utils import get_max_shared_memory_bytes

 if not current_platform.is_rocm():
     from xformers import ops as xops
@@ -7,7 +7,7 @@ from unittest.mock import patch
 import pytest

 from vllm import LLM
-from vllm.utils import GiB_bytes
+from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.core.kv_cache_utils import (
     generate_scheduler_kv_cache_config,
     get_kv_cache_configs,
@@ -46,10 +46,10 @@ from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import (
     FlexibleArgumentParser,
-    GB_bytes,
     cuda_device_count_stateless,
     get_open_port,
 )
+from vllm.utils.mem_constants import GB_bytes

 if current_platform.is_rocm():
     from amdsmi import (
@@ -23,7 +23,6 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens

 from vllm.utils import (
     FlexibleArgumentParser,
-    MemorySnapshot,
     bind_kv_cache,
     common_broadcastable_dtype,
     current_stream,
@@ -33,13 +32,13 @@ from vllm.utils import (
     join_host_port,
     make_zmq_path,
     make_zmq_socket,
-    memory_profiling,
     sha256,
     split_host_port,
     split_zmq_path,
     unique_filepath,
 )

+from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
 from ..utils import create_new_process_for_each_test, flat_product


@@ -14,7 +14,8 @@ from vllm.multimodal.inputs import (
     PlaceholderRange,
 )
 from vllm.sampling_params import SamplingParams
-from vllm.utils import GiB_bytes, sha256, sha256_cbor
+from vllm.utils import sha256, sha256_cbor
+from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.core.kv_cache_utils import (
     BlockHash,
@@ -13,7 +13,7 @@ from vllm.config import (
 )
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
-from vllm.utils import GiB_bytes
+from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
 from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
 from vllm.v1.worker.tpu_model_runner import (
@@ -21,7 +21,8 @@ from vllm.distributed.parallel_state import (
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
-from vllm.utils import GiB_bytes, update_environment_variables
+from vllm.utils import update_environment_variables
+from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
 from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
 from vllm.v1.kv_cache_interface import (
@@ -11,7 +11,7 @@ import pytest
 import torch

 from vllm.engine.arg_utils import EngineArgs
-from vllm.utils import MemorySnapshot
+from vllm.utils.mem_utils import MemorySnapshot
 from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment

 # Global queue to track operation order across processes
@@ -10,7 +10,8 @@ from pydantic.dataclasses import dataclass

 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, get_cpu_memory
+from vllm.utils.mem_constants import GiB_bytes
+from vllm.utils.mem_utils import get_cpu_memory

 if TYPE_CHECKING:
     from vllm.config.parallel import ParallelConfig
@@ -81,7 +81,8 @@ from vllm.transformers_utils.config import (
     maybe_override_with_speculators,
 )
 from vllm.transformers_utils.utils import check_gguf_file
-from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor
+from vllm.utils import FlexibleArgumentParser, get_ip, is_in_ray_actor
+from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.sample.logits_processor import LogitsProcessor

 if TYPE_CHECKING:
@@ -17,9 +17,9 @@ from vllm.distributed.device_communicators.shm_object_storage import (
     SingleWriterShmRingBuffer,
 )
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, MiB_bytes
 from vllm.utils.cache import CacheInfo, LRUCache
 from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
+from vllm.utils.mem_constants import GiB_bytes, MiB_bytes

 from .inputs import (
     MultiModalBatchedField,
@@ -151,7 +151,7 @@ class CpuPlatform(Platform):
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         import vllm.envs as envs
-        from vllm.utils import GiB_bytes
+        from vllm.utils.mem_constants import GiB_bytes

         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
         if kv_cache_space is None:
@@ -4,7 +4,6 @@
 import contextlib
 import datetime
 import enum
-import gc
 import getpass
 import hashlib
 import importlib
@@ -21,7 +20,6 @@ import sys
 import tempfile
 import textwrap
 import threading
-import time
 import traceback
 import uuid
 import warnings
@@ -38,12 +36,10 @@ from collections import defaultdict
 from collections.abc import (
     Callable,
     Collection,
-    Generator,
     Iterator,
     Sequence,
 )
 from concurrent.futures.process import ProcessPoolExecutor
-from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, TextIO, TypeVar
@@ -58,7 +54,6 @@ import psutil
 import regex as re
 import setproctitle
 import torch
-import torch.types
 import yaml
 import zmq
 import zmq.asyncio
@@ -105,17 +100,6 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
 STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"

-MB_bytes = 1_000_000
-"""The number of bytes in one megabyte (MB)."""
-
-MiB_bytes = 1 << 20
-"""The number of bytes in one mebibyte (MiB)."""
-
-GB_bytes = 1_000_000_000
-"""The number of bytes in one gigabyte (GB)."""
-
-GiB_bytes = 1 << 30
-"""The number of bytes in one gibibyte (GiB)."""

 # ANSI color codes
 CYAN = "\033[1;36m"
@@ -180,23 +164,6 @@ class Counter:
         self.counter = 0


-@cache
-def get_max_shared_memory_bytes(gpu: int = 0) -> int:
-    """Returns the maximum shared memory per thread block in bytes."""
-    from vllm import _custom_ops as ops
-
-    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
-    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
-    # will fail
-    assert max_shared_mem > 0, "max_shared_mem can not be zero"
-    return int(max_shared_mem)
-
-
-def get_cpu_memory() -> int:
-    """Returns the total CPU memory of the node in bytes."""
-    return psutil.virtual_memory().total
-
-
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)

@@ -581,30 +548,6 @@ def is_uva_available() -> bool:
     return is_pin_memory_available()


-class DeviceMemoryProfiler:
-    def __init__(self, device: torch.types.Device | None = None):
-        self.device = device
-
-    def current_memory_usage(self) -> float:
-        # Return the memory usage in bytes.
-        from vllm.platforms import current_platform
-
-        gc.collect()
-        return current_platform.get_current_memory_usage(self.device)
-
-    def __enter__(self):
-        self.initial_memory = self.current_memory_usage()
-        # This allows us to call methods of the context manager if needed
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.final_memory = self.current_memory_usage()
-        self.consumed_memory = self.final_memory - self.initial_memory
-
-        # Force garbage collection
-        gc.collect()
-
-
 def make_ndarray_with_pad(
     x: list[list[T]],
     pad: T,
@@ -1642,183 +1585,6 @@ def kill_process_tree(pid: int):
         os.kill(pid, signal.SIGKILL)


-@dataclass
-class MemorySnapshot:
-    """Memory snapshot."""
-
-    torch_peak: int = 0
-    free_memory: int = 0
-    total_memory: int = 0
-    cuda_memory: int = 0
-    torch_memory: int = 0
-    non_torch_memory: int = 0
-    timestamp: float = 0.0
-    auto_measure: bool = True
-
-    def __post_init__(self):
-        if self.auto_measure:
-            self.measure()
-
-    def measure(self):
-        from vllm.platforms import current_platform
-
-        # we measure the torch peak memory usage via allocated_bytes,
-        # rather than `torch.cuda.memory_reserved()` .
-        # After `torch.cuda.reset_peak_memory_stats()`,
-        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
-        # when we call `torch.cuda.empty_cache()` or OOM happens.
-        self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-
-        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
-        shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1))  # Orin, Thor, Spark
-        if (
-            current_platform.is_cuda()
-            and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
-        ):
-            # On UMA (Orin, Thor and Spark) platform,
-            # where both CPU and GPU rely on system memory,
-            # the cudaMemGetInfo function shows the amount of free system memory
-            # rather than what’s actually available.
-            # In the case,
-            # torch.cuda.mem_get_info() only reports "free" memory,
-            # which can be lower than what is actually
-            # available due to not including cache memory.
-            # There’s also a comprehensive reference page
-            # that explains how you can compute the proper value yourself.
-            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
-            self.free_memory = psutil.virtual_memory().available
-
-        self.cuda_memory = self.total_memory - self.free_memory
-
-        # torch.cuda.memory_reserved() is how many bytes
-        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
-        # this is used to measure the non-torch memory usage
-        self.torch_memory = torch.cuda.memory_reserved()
-
-        self.non_torch_memory = self.cuda_memory - self.torch_memory
-        self.timestamp = time.time()
-
-    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
-        return MemorySnapshot(
-            torch_peak=self.torch_peak - other.torch_peak,
-            free_memory=self.free_memory - other.free_memory,
-            total_memory=self.total_memory - other.total_memory,
-            cuda_memory=self.cuda_memory - other.cuda_memory,
-            torch_memory=self.torch_memory - other.torch_memory,
-            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
-            timestamp=self.timestamp - other.timestamp,
-            auto_measure=False,
-        )
-
-
-@dataclass
-class MemoryProfilingResult:
-    """Memory profiling result. All numbers are in bytes."""
-
-    non_kv_cache_memory: int = 0
-    torch_peak_increase: int = 0
-    non_torch_increase: int = 0
-    weights_memory: float = 0
-    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
-    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
-    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
-    profile_time: float = 0.0
-
-    def __repr__(self) -> str:
-        return (
-            f"Memory profiling takes {self.profile_time:.2f} seconds. "
-            f"Total non KV cache memory: "
-            f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
-            f"torch peak memory increase: "
-            f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
-            f"non-torch forward increase memory: "
-            f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
-            f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB."
-        )
-
-
-@contextlib.contextmanager
-def memory_profiling(
-    baseline_snapshot: MemorySnapshot, weights_memory: int
-) -> Generator[MemoryProfilingResult, None, None]:
-    """Memory profiling context manager.
-    baseline_snapshot: the memory snapshot before the current vLLM instance.
-    weights_memory: memory used by PyTorch when loading the model weights.
-    Note that, before loading the model weights, we also initialize the device
-    and distributed environment, which may consume some memory. This part is not
-    included in the weights_memory because PyTorch does not control it.
-
-    The memory in one GPU can be classified into 3 categories:
-    1. memory used by anything other than the current vLLM instance.
-    2. memory used by torch in the current vLLM instance.
-    3. memory used in the current vLLM instance, but not by torch.
-
-    A quantitive example:
-
-    Before creating the current vLLM instance:
-    category 1: 1 GiB
-    category 2: 0 GiB
-    category 3: 0 GiB
-
-    After creating the current vLLM instance and loading the model,
-    (i.e. before profiling):
-    category 1: 1 GiB
-    category 2: 2 GiB (model weights take 2 GiB)
-    category 3: 0.5 GiB (memory used by NCCL)
-
-    During profiling (peak):
-    category 1: 1 GiB
-    category 2: 4 GiB (peak activation tensors take 2 GiB)
-    category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
-
-    After profiling:
-    category 1: 1 GiB
-    category 2: 3 GiB (after garbage-collecting activation tensors)
-    category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
-
-    In this case, non-kv cache takes 5 GiB in total, including:
-    a. 2 GiB used by the model weights (category 2)
-    b. 2 GiB reserved for the peak activation tensors (category 2)
-    c. 1 GiB used by non-torch components (category 3)
-
-    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
-
-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
-
-    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
-    """  # noqa
-    gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
-
-    result = MemoryProfilingResult()
-
-    result.before_create = baseline_snapshot
-    # the part of memory used for holding the model weights
-    result.weights_memory = weights_memory
-
-    result.before_profile.measure()
-
-    yield result
-
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    result.after_profile.measure()
-
-    diff_profile = result.after_profile - result.before_profile
-    diff_from_create = result.after_profile - result.before_create
-    result.torch_peak_increase = diff_profile.torch_peak
-    result.non_torch_increase = diff_from_create.non_torch_memory
-    result.profile_time = diff_profile.timestamp
-
-    non_torch_memory = result.non_torch_increase
-    peak_activation_memory = result.torch_peak_increase
-    result.non_kv_cache_memory = (
-        non_torch_memory + peak_activation_memory + result.weights_memory
-    )  # noqa
-
-
 # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630  # noqa: E501
 def set_ulimit(target_soft_limit=65535):
     if sys.platform.startswith("win"):
vllm/utils/mem_constants.py (new file, 13 lines added)
@@ -0,0 +1,13 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+MB_bytes = 1_000_000
+"""The number of bytes in one megabyte (MB)."""
+
+MiB_bytes = 1 << 20
+"""The number of bytes in one mebibyte (MiB)."""
+
+GB_bytes = 1_000_000_000
+"""The number of bytes in one gigabyte (GB)."""
+
+GiB_bytes = 1 << 30
+"""The number of bytes in one gibibyte (GiB)."""
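As a quick sanity check on the constants above (the arithmetic follows directly from the definitions; the import path assumes a vLLM tree at this commit):

from vllm.utils.mem_constants import GB_bytes, GiB_bytes, MB_bytes, MiB_bytes

assert MiB_bytes == 1_048_576        # 1 << 20, binary mebibyte
assert GiB_bytes == 1_073_741_824    # 1 << 30, binary gibibyte
assert MB_bytes == 1_000_000         # decimal megabyte
assert GB_bytes == 1_000_000_000     # decimal gigabyte
print(f"16 GiB = {16 * GiB_bytes} bytes")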
vllm/utils/mem_utils.py (new file, 232 lines added)
@@ -0,0 +1,232 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
+import gc
+import time
+from collections.abc import Generator
+from dataclasses import dataclass, field
+from functools import cache
+
+import psutil
+import torch
+import torch.types
+
+from .mem_constants import GiB_bytes
+
+
+@cache
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    from vllm import _custom_ops as ops
+
+    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
+    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
+    # will fail
+    assert max_shared_mem > 0, "max_shared_mem can not be zero"
+    return int(max_shared_mem)
+
+
+def get_cpu_memory() -> int:
+    """Returns the total CPU memory of the node in bytes."""
+    return psutil.virtual_memory().total
+
+
+class DeviceMemoryProfiler:
+    def __init__(self, device: torch.types.Device | None = None):
+        self.device = device
+
+    def current_memory_usage(self) -> float:
+        # Return the memory usage in bytes.
+        from vllm.platforms import current_platform
+
+        gc.collect()
+        return current_platform.get_current_memory_usage(self.device)
+
+    def __enter__(self):
+        self.initial_memory = self.current_memory_usage()
+        # This allows us to call methods of the context manager if needed
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.final_memory = self.current_memory_usage()
+        self.consumed_memory = self.final_memory - self.initial_memory
+
+        # Force garbage collection
+        gc.collect()
+
+
+@dataclass
+class MemorySnapshot:
+    """Memory snapshot."""
+
+    torch_peak: int = 0
+    free_memory: int = 0
+    total_memory: int = 0
+    cuda_memory: int = 0
+    torch_memory: int = 0
+    non_torch_memory: int = 0
+    timestamp: float = 0.0
+    auto_measure: bool = True
+
+    def __post_init__(self):
+        if self.auto_measure:
+            self.measure()
+
+    def measure(self):
+        from vllm.platforms import current_platform
+
+        # we measure the torch peak memory usage via allocated_bytes,
+        # rather than `torch.cuda.memory_reserved()` .
+        # After `torch.cuda.reset_peak_memory_stats()`,
+        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
+        # when we call `torch.cuda.empty_cache()` or OOM happens.
+        self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+
+        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
+        shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1))  # Orin, Thor, Spark
+        if (
+            current_platform.is_cuda()
+            and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
+        ):
+            # On UMA (Orin, Thor and Spark) platform,
+            # where both CPU and GPU rely on system memory,
+            # the cudaMemGetInfo function shows the amount of free system memory
+            # rather than what’s actually available.
+            # In the case,
+            # torch.cuda.mem_get_info() only reports "free" memory,
+            # which can be lower than what is actually
+            # available due to not including cache memory.
+            # There’s also a comprehensive reference page
+            # that explains how you can compute the proper value yourself.
+            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
+            self.free_memory = psutil.virtual_memory().available
+
+        self.cuda_memory = self.total_memory - self.free_memory
+
+        # torch.cuda.memory_reserved() is how many bytes
+        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
+        # this is used to measure the non-torch memory usage
+        self.torch_memory = torch.cuda.memory_reserved()
+
+        self.non_torch_memory = self.cuda_memory - self.torch_memory
+        self.timestamp = time.time()
+
+    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
+        return MemorySnapshot(
+            torch_peak=self.torch_peak - other.torch_peak,
+            free_memory=self.free_memory - other.free_memory,
+            total_memory=self.total_memory - other.total_memory,
+            cuda_memory=self.cuda_memory - other.cuda_memory,
+            torch_memory=self.torch_memory - other.torch_memory,
+            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
+            timestamp=self.timestamp - other.timestamp,
+            auto_measure=False,
+        )
+
+
+@dataclass
+class MemoryProfilingResult:
+    """Memory profiling result. All numbers are in bytes."""
+
+    non_kv_cache_memory: int = 0
+    torch_peak_increase: int = 0
+    non_torch_increase: int = 0
+    weights_memory: float = 0
+    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
+    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    profile_time: float = 0.0
+
+    def __repr__(self) -> str:
+        return (
+            f"Memory profiling takes {self.profile_time:.2f} seconds. "
+            f"Total non KV cache memory: "
+            f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
+            f"torch peak memory increase: "
+            f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
+            f"non-torch forward increase memory: "
+            f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
+            f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB."
+        )
+
+
+@contextlib.contextmanager
+def memory_profiling(
+    baseline_snapshot: MemorySnapshot, weights_memory: int
+) -> Generator[MemoryProfilingResult, None, None]:
+    """Memory profiling context manager.
+    baseline_snapshot: the memory snapshot before the current vLLM instance.
+    weights_memory: memory used by PyTorch when loading the model weights.
+    Note that, before loading the model weights, we also initialize the device
+    and distributed environment, which may consume some memory. This part is not
+    included in the weights_memory because PyTorch does not control it.
+
+    The memory in one GPU can be classified into 3 categories:
+    1. memory used by anything other than the current vLLM instance.
+    2. memory used by torch in the current vLLM instance.
+    3. memory used in the current vLLM instance, but not by torch.
+
+    A quantitive example:
+
+    Before creating the current vLLM instance:
+    category 1: 1 GiB
+    category 2: 0 GiB
+    category 3: 0 GiB
+
+    After creating the current vLLM instance and loading the model,
+    (i.e. before profiling):
+    category 1: 1 GiB
+    category 2: 2 GiB (model weights take 2 GiB)
+    category 3: 0.5 GiB (memory used by NCCL)
+
+    During profiling (peak):
+    category 1: 1 GiB
+    category 2: 4 GiB (peak activation tensors take 2 GiB)
+    category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    After profiling:
+    category 1: 1 GiB
+    category 2: 3 GiB (after garbage-collecting activation tensors)
+    category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    In this case, non-kv cache takes 5 GiB in total, including:
+    a. 2 GiB used by the model weights (category 2)
+    b. 2 GiB reserved for the peak activation tensors (category 2)
+    c. 1 GiB used by non-torch components (category 3)
+
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
+
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
+
+    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
+    """  # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+
+    result = MemoryProfilingResult()
+
+    result.before_create = baseline_snapshot
+    # the part of memory used for holding the model weights
+    result.weights_memory = weights_memory
+
+    result.before_profile.measure()
+
+    yield result
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    result.after_profile.measure()
+
+    diff_profile = result.after_profile - result.before_profile
+    diff_from_create = result.after_profile - result.before_create
+    result.torch_peak_increase = diff_profile.torch_peak
+    result.non_torch_increase = diff_from_create.non_torch_memory
+    result.profile_time = diff_profile.timestamp
+
+    non_torch_memory = result.non_torch_increase
+    peak_activation_memory = result.torch_peak_increase
+    result.non_kv_cache_memory = (
+        non_torch_memory + peak_activation_memory + result.weights_memory
+    )  # noqa
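For context, a minimal usage sketch of the relocated profiling helpers (not part of this commit; it assumes a CUDA device, and the weights_memory value and the workload tensor below are purely illustrative stand-ins):

import torch

from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling

baseline = MemorySnapshot()       # auto_measure=True, so it measures on construction
weights_memory = 2 * GiB_bytes    # illustrative value, not measured here

with memory_profiling(baseline, weights_memory) as result:
    # stand-in for a real profiling run (e.g. a dummy forward pass)
    workload = torch.empty(256 * 1024 * 1024, dtype=torch.float16, device="cuda")
    del workload

print(f"Non-KV-cache memory: {result.non_kv_cache_memory / GiB_bytes:.2f} GiB")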
@@ -12,7 +12,8 @@ from typing import Any, NewType, TypeAlias
 from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, cdiv, sha256_cbor
+from vllm.utils import cdiv, sha256_cbor
+from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.kv_cache_interface import (
     ChunkedLocalAttentionSpec,
     FullAttentionSpec,
@@ -74,8 +74,6 @@ from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (
     STR_DTYPE_TO_TORCH_DTYPE,
-    DeviceMemoryProfiler,
-    GiB_bytes,
     cdiv,
     check_use_alibi,
     get_dtype_size,
@@ -85,6 +83,8 @@ from vllm.utils import (
     supports_dynamo,
 )
 from vllm.utils.jsontree import json_map_leaves
+from vllm.utils.mem_constants import GiB_bytes
+from vllm.utils.mem_utils import DeviceMemoryProfiler
 from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
@@ -28,7 +28,8 @@ from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
-from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
+from vllm.utils.mem_constants import GiB_bytes
+from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (