[sleep mode] save memory for on-the-fly quantization (#24731)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 7a1c4025f1
commit fdb09c77d6
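The diff below touches vLLM's CuMemAllocator, which backs sleep mode. For context, here is a minimal sketch of how sleep mode is typically driven from the public vllm.LLM API; the model name and token count are illustrative only, and a CUDA device with enable_sleep_mode=True is assumed.

# Sketch (not part of this commit): exercising sleep mode from the public API,
# which routes GPU allocations through CuMemAllocator under the hood.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_sleep_mode=True)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))

# level=1 backs up weights to CPU and discards the KV cache;
# level=2 discards the weights as well (useful when they will be rebuilt,
# e.g. after on-the-fly quantization produces new weights).
llm.sleep(level=1)
# ... GPU memory is available for other work here ...
llm.wake_up()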
@@ -16,8 +16,11 @@ from typing import Any, Callable, Optional, Union
 
 import torch
 
+from vllm.logger import init_logger
 from vllm.utils import is_pin_memory_available
 
+logger = init_logger(__name__)
+
 
 def find_loaded_library(lib_name) -> Optional[str]:
     """
@@ -165,6 +168,9 @@ class CuMemAllocator:
         py_d_mem = allocation_handle[2]
         self.pointer_to_data[py_d_mem] = AllocationData(
             allocation_handle, self.current_tag)
+        logger.debug(
+            "Allocated %s bytes for %s with address %s from cumem allocator",
+            allocation_handle[1], self.current_tag, py_d_mem)
         return
 
     def _python_free_callback(self, ptr: int) -> HandleType:
@@ -174,6 +180,9 @@ class CuMemAllocator:
         data = self.pointer_to_data.pop(ptr)
         if data.cpu_backup_tensor is not None:
             data.cpu_backup_tensor = None
+        logger.debug(
+            "Freed %s bytes for %s with address %s from cumem allocator",
+            data.handle[1], data.tag, ptr)
         return data.handle
 
     def sleep(
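The two logger.debug calls added above only show up when vLLM logging is turned up to DEBUG. A minimal sketch, assuming the VLLM_LOGGING_LEVEL environment variable controls vLLM's log level and is set before vllm is imported:

import os

# Assumption: VLLM_LOGGING_LEVEL (default INFO) must be set before
# importing vllm for the allocator's debug messages above to appear.
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"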
@@ -197,9 +206,14 @@ class CuMemAllocator:
 
         assert isinstance(offload_tags, tuple)
 
+        total_bytes = 0
+        backup_bytes = 0
+
         for ptr, data in self.pointer_to_data.items():
             handle = data.handle
+            total_bytes += handle[1]
             if data.tag in offload_tags:
+                backup_bytes += handle[1]
                 size_in_bytes = handle[1]
                 cpu_backup_tensor = torch.empty(
                     size_in_bytes,
@@ -211,6 +225,12 @@ class CuMemAllocator:
                 data.cpu_backup_tensor = cpu_backup_tensor
             unmap_and_release(handle)
 
+        logger.info(
+            "CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
+            "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
+            "directly.", total_bytes / 1024**3, backup_bytes / 1024**3,
+            (total_bytes - backup_bytes) / 1024**3)
+
         gc.collect()
         torch.cuda.empty_cache()
 
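As a toy illustration of the accounting added to sleep(): handle[1] carries the allocation size in bytes, and only allocations whose tag is in offload_tags get a CPU backup, so the new log line splits the freed memory into a backed-up part and a discarded part. The tags and sizes below are made up.

# Toy illustration (not vLLM code) of the sleep() accounting above.
offload_tags = ("weights", )
# hypothetical (tag, size_in_bytes) pairs standing in for pointer_to_data
allocs = [("weights", 6 << 30), ("kv_cache", 10 << 30)]

total_bytes = sum(size for _, size in allocs)
backup_bytes = sum(size for tag, size in allocs if tag in offload_tags)
print(f"freed {total_bytes / 1024**3:.2f} GiB, "
      f"backed up {backup_bytes / 1024**3:.2f} GiB, "
      f"discarded {(total_bytes - backup_bytes) / 1024**3:.2f} GiB")
# -> freed 16.00 GiB, backed up 6.00 GiB, discarded 10.00 GiB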
@@ -267,12 +287,17 @@ class CuMemAllocator:
             # when using pluggable allocator, see
             # https://github.com/pytorch/pytorch/issues/145168 .
             # if we have some memory allocated and then freed,
-            # the memory will not be released.
-            # right now it is fine, because we only use this allocator
-            # during weight loading and kv cache creation, where we only
-            # allocate memory.
-            # TODO: we need to find a way to release the memory,
-            # i.e. calling torch.cuda.empty_cache()
+            # the memory will not be released, e.g. in online quantization,
+            # where the model is created in higher precision, and then
+            # quantized in lower precision.
+            # Find all unused allocations and manually release them.
+            # TODO: we should expose `empty_cache` method in the memory pool.
+            # TODO: ask for help from PyTorch team to expose this method.
+            allocations = data[0].snapshot()
+            for allocation in allocations:
+                if allocation["allocated_size"] == 0:
+                    handle = self._python_free_callback(allocation["address"])
+                    unmap_and_release(handle)
             self.current_tag = old_tag
 
     def get_current_usage(self) -> int:
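For context, a sketch of what the new cleanup loop does. It assumes data[0] is the torch.cuda.MemPool created by use_memory_pool_with_allocator and that each snapshot() entry exposes "address" and "allocated_size" keys, as the diff above relies on; the helper name is hypothetical.

def release_unused_segments(pool, free_callback, unmap_and_release):
    """Return physical memory for segments that hold no live allocations.

    A segment with allocated_size == 0 was allocated and later freed by the
    caller (e.g. high-precision weights dropped after on-the-fly
    quantization), but the pluggable allocator keeps it cached instead of
    releasing it, so we unmap and release it manually.
    """
    for segment in pool.snapshot():
        if segment["allocated_size"] == 0:
            handle = free_callback(segment["address"])
            unmap_and_release(handle)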