Measure model memory usage (#3120)
This commit is contained in:
parent 2daf23ab0c
commit 385da2dae2
vllm/utils.py
@@ -3,6 +3,7 @@ import os
 import socket
 import subprocess
 import uuid
+import gc
 from platform import uname
 from typing import List, Tuple, Union
 from packaging.version import parse, Version
@@ -309,3 +310,27 @@ def create_kv_caches_with_random(
                 f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
+
+
+class measure_cuda_memory:
+
+    def __init__(self, device=None):
+        self.device = device
+
+    def current_memory_usage(self) -> float:
+        # Return the memory usage in bytes.
+        torch.cuda.reset_peak_memory_stats(self.device)
+        mem = torch.cuda.max_memory_allocated(self.device)
+        return mem
+
+    def __enter__(self):
+        self.initial_memory = self.current_memory_usage()
+        # This allows us to call methods of the context manager if needed
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.final_memory = self.current_memory_usage()
+        self.consumed_memory = self.final_memory - self.initial_memory
+
+        # Force garbage collection
+        gc.collect()
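Usage note: current_memory_usage() resets the CUDA peak-memory statistics and then reads torch.cuda.max_memory_allocated(), which immediately after a reset equals the bytes currently allocated, so consumed_memory is the net allocation left resident between entering and leaving the block. A minimal standalone sketch of how the context manager can be driven, assuming a CUDA device is available and that measure_cuda_memory is imported from vllm.utils; the tensor allocation here is purely illustrative:

import torch

from vllm.utils import measure_cuda_memory

# Wrap an arbitrary CUDA allocation; consumed_memory is the byte delta
# between entering and leaving the block (illustrative tensor, not vLLM code).
with measure_cuda_memory() as m:
    weights = torch.empty(1024, 1024, dtype=torch.float16, device="cuda")

print(f"Block left {m.consumed_memory / float(2**30):.4f} GB resident on the GPU")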
||||
vllm/worker/model_runner.py
@@ -21,7 +21,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, measure_cuda_memory

 logger = init_logger(__name__)

||||
@@ -85,11 +85,17 @@ class ModelRunner:
             self.model_config.enforce_eager = True

     def load_model(self) -> None:
-        self.model = get_model(self.model_config,
-                               self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        with measure_cuda_memory() as m:
+            self.model = get_model(self.model_config,
+                                   self.device_config,
+                                   lora_config=self.lora_config,
+                                   parallel_config=self.parallel_config,
+                                   scheduler_config=self.scheduler_config)
+
+        self.model_memory_usage = m.consumed_memory
+        logger.info(
+            f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
+        )

         vocab_size = self.model.config.vocab_size
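The logged figure can be reproduced with the same pattern outside ModelRunner; a sketch under the assumption that a CUDA device is present, with a small torch.nn.Linear standing in for the actual get_model(...) call:

import torch

from vllm.utils import measure_cuda_memory

# Stand-in for load_model(): measure what the weight allocation costs on the GPU
# and report it the same way the new logger.info call does.
with measure_cuda_memory() as m:
    model = torch.nn.Linear(4096, 4096, dtype=torch.float16, device="cuda")

model_memory_usage = m.consumed_memory
print(f"Loading model weights took {model_memory_usage / float(2**30):.4f} GB")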