Measure model memory usage (#3120)

Michael Goin 2024-03-07 11:42:42 -08:00 committed by GitHub
parent 2daf23ab0c
commit 385da2dae2
2 changed files with 37 additions and 6 deletions

vllm/utils.py

@@ -3,6 +3,7 @@ import os
 import socket
 import subprocess
 import uuid
+import gc
 from platform import uname
 from typing import List, Tuple, Union
 from packaging.version import parse, Version
@@ -309,3 +310,27 @@ def create_kv_caches_with_random(
                 f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
+
+
+class measure_cuda_memory:
+
+    def __init__(self, device=None):
+        self.device = device
+
+    def current_memory_usage(self) -> float:
+        # Return the memory usage in bytes.
+        torch.cuda.reset_peak_memory_stats(self.device)
+        mem = torch.cuda.max_memory_allocated(self.device)
+        return mem
+
+    def __enter__(self):
+        self.initial_memory = self.current_memory_usage()
+        # This allows us to call methods of the context manager if needed
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.final_memory = self.current_memory_usage()
+        self.consumed_memory = self.final_memory - self.initial_memory
+
+        # Force garbage collection
+        gc.collect()
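For context, a minimal sketch of how the new context manager can be used on its own; the tensor allocation below is an illustrative stand-in (not part of this commit) for any CUDA work whose retained footprint you want to measure:

import torch

from vllm.utils import measure_cuda_memory

# Illustrative standalone usage: measure the CUDA memory retained by an
# allocation made inside the block.
with measure_cuda_memory() as m:
    weights = torch.empty(1024, 1024, device="cuda")  # ~4 MiB fp32

print(f"Consumed {m.consumed_memory / float(2**30):.4f} GB")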

vllm/worker/model_runner.py

@@ -21,7 +21,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, measure_cuda_memory

 logger = init_logger(__name__)
@@ -85,12 +85,18 @@ class ModelRunner:
             self.model_config.enforce_eager = True

     def load_model(self) -> None:
-        self.model = get_model(self.model_config,
-                               self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        with measure_cuda_memory() as m:
+            self.model = get_model(self.model_config,
+                                   self.device_config,
+                                   lora_config=self.lora_config,
+                                   parallel_config=self.parallel_config,
+                                   scheduler_config=self.scheduler_config)
+
+        self.model_memory_usage = m.consumed_memory
+        logger.info(
+            f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
+        )

         vocab_size = self.model.config.vocab_size
         if self.lora_config:
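A note on what gets logged: because current_memory_usage resets the peak statistics before reading max_memory_allocated, consumed_memory is the change in currently allocated bytes between entry and exit — the memory retained by the loaded weights, not the transient peak during loading. A hypothetical sketch (not part of this commit) demonstrating that distinction:

import torch

from vllm.utils import measure_cuda_memory

# Illustrative only: a temporary buffer freed inside the block does not
# count toward consumed_memory, since the context manager reports the
# delta in currently-allocated bytes rather than the peak.
with measure_cuda_memory() as m:
    tmp = torch.empty(512, 1024, 1024, device="cuda")   # ~2 GiB, transient
    del tmp
    kept = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB, retained

print(f"{m.consumed_memory / float(2**30):.4f} GB")  # ~1.0000, not ~3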