[V1] Fix Compilation config & Enable CUDA graph by default (#10528)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 7560ae5caf
commit f9310cbd0c
@@ -2370,7 +2370,7 @@ class VllmConfig:
         if self.compilation_config is None:
             self.compilation_config = CompilationConfig()
-        if envs.VLLM_USE_V1:
+        if envs.VLLM_USE_V1 and not self.model_config.enforce_eager:
             # NOTE(woosuk): Currently, we use inductor because the piecewise
             # CUDA graphs do not work properly with the custom CUDA kernels.
             # FIXME(woosuk): Disable inductor to reduce the compilation time
@@ -2380,6 +2380,7 @@ class VllmConfig:
             self.compilation_config.use_inductor = True
             self.compilation_config.pass_config.enable_fusion = False
             self.compilation_config.pass_config.enable_reshape = False
+            self.compilation_config.level = CompilationLevel.PIECEWISE

         current_platform.check_and_update_config(self)
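The net effect of this first hunk is that the V1 engine now turns on piecewise compilation, and hence CUDA graph capture, by default, with enforce_eager as the opt-out. A minimal usage sketch of how the toggle behaves from the user side, assuming the standard LLM entry point and the VLLM_USE_V1 environment variable (the model name is only an example):

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM

# Default: compilation_config is auto-created, its level is set to
# PIECEWISE, and piecewise CUDA graphs are captured at startup.
llm = LLM(model="facebook/opt-125m")

# Opt out: enforce_eager=True now skips the V1 compilation path entirely,
# so there is no Inductor compilation and no CUDA graph capture.
llm_eager = LLM(model="facebook/opt-125m", enforce_eager=True)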
@@ -1,3 +1,4 @@
+import gc
 import time
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@@ -515,7 +516,25 @@ class GPUModelRunner:
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))

-    def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
+    @torch.inference_mode()
+    def _dummy_run(
+        self,
+        model: nn.Module,
+        num_tokens: int,
+        kv_caches: List[torch.Tensor],
+    ) -> torch.Tensor:
+        with set_forward_context(None):
+            hidden_states = model(
+                input_ids=None,
+                positions=self.positions[:num_tokens],
+                kv_caches=kv_caches,
+                attn_metadata=None,
+                inputs_embeds=self.inputs_embeds[:num_tokens])
+        return hidden_states
+
+    def profile_run(self) -> None:
+        # TODO(woosuk): Profile the max memory usage of the encoder and
+        # the encoder cache.
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
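The rewritten _dummy_run now owns the whole placeholder forward pass (inference mode, forward context, buffer slicing) and returns the hidden states, so GPUModelRunner can reuse one helper for profiling and for graph capture alike. A self-contained sketch of the same warm-up pattern with a toy model, since the runner itself is not excerpted here (all names below are illustrative, not vLLM's):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Stand-in for the real model; vLLM's forward also takes kv_caches
    # and attention metadata, omitted here.
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.proj(inputs_embeds)

@torch.inference_mode()
def dummy_run(model: nn.Module, num_tokens: int,
              inputs_embeds: torch.Tensor) -> torch.Tensor:
    # Slice a preallocated max-size buffer instead of allocating fresh
    # inputs, mirroring self.inputs_embeds[:num_tokens] in the diff.
    return model(inputs_embeds=inputs_embeds[:num_tokens])

model = ToyModel()
buffer = torch.zeros(1024, 64)          # preallocated for the max tokens
hidden = dummy_run(model, 256, buffer)  # warm-up at one token count
print(hidden.shape)                     # torch.Size([256, 64])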
@@ -527,23 +546,17 @@
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
-        with set_forward_context(None):  # noqa: SIM117
-            with set_compile_context(self.cudagraph_batch_sizes):
-                # Trigger compilation for general shape.
-                model(input_ids=None,
-                      positions=self.positions,
-                      kv_caches=dummy_kv_caches,
-                      attn_metadata=None,
-                      inputs_embeds=self.inputs_embeds)
-
-    @torch.inference_mode()
-    def profile_run(self) -> None:
-        # TODO(woosuk): Profile the max memory usage of the encoder and
-        # the encoder cache.
-        self._dummy_run(self.model, self.max_num_tokens)
+        with set_compile_context(self.cudagraph_batch_sizes):
+            # Trigger compilation for general shape.
+            hidden_states = self._dummy_run(self.model, self.max_num_tokens,
+                                            dummy_kv_caches)
+        logits = self.model.compute_logits(hidden_states, None)
+        logits = logits[:self.max_num_tokens]
+        # TODO(woosuk): Consider the memory usage of the sampler.
+        torch.cuda.synchronize()
+        del hidden_states, logits
+        gc.collect()

     @torch.inference_mode()
     def capture_model(self) -> None:
         if not self.use_cuda_graph:
             logger.warning(
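profile_run now exercises the logits computation as well, then explicitly drops every tensor it created before control returns to the worker. That matters because the worker reads free GPU memory and torch's peak-allocation stats right after this call; any surviving reference would be miscounted as persistent model memory. A minimal sketch of the measure-then-release pattern (illustrative names, assumes a CUDA device):

import gc
import torch

@torch.inference_mode()
def profile_peak(model: torch.nn.Module, dummy_input: torch.Tensor) -> int:
    torch.cuda.reset_peak_memory_stats()
    hidden = model(dummy_input)      # transient activations live here
    torch.cuda.synchronize()         # ensure all kernels have finished
    peak = torch.cuda.max_memory_allocated()
    del hidden                       # drop the last reference so the
    gc.collect()                     # caller's free-memory reading only
    return peak                      # reflects persistent allocations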
@@ -554,18 +567,11 @@
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]

-        with set_forward_context(None):
-            # Trigger CUDA graph capture for specific shapes.
-            # Capture the large shapes first so that the smaller shapes
-            # can reuse the memory pool allocated for the large shapes.
-            for num_tokens in reversed(self.cudagraph_batch_sizes):
-                self.model(
-                    input_ids=None,
-                    positions=self.positions[:num_tokens],
-                    kv_caches=self.kv_caches,
-                    attn_metadata=None,
-                    inputs_embeds=self.inputs_embeds[:num_tokens],
-                )
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        for num_tokens in reversed(self.cudagraph_batch_sizes):
+            self._dummy_run(self.model, num_tokens, self.kv_caches)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
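capture_model now funnels graph capture through the same _dummy_run helper while keeping the largest-shape-first ordering. That ordering matters because captures against a shared memory pool let smaller shapes reuse blocks the largest capture already reserved. A sketch of the underlying idea with raw torch.cuda.CUDAGraph (vLLM's piecewise compilation backend manages capture internally, so this is illustrative only, and it assumes a CUDA device):

import torch

model = torch.nn.Linear(512, 512).cuda()
static_input = torch.zeros(256, 512, device="cuda")

model(static_input)       # eager warm-up: lazy cuBLAS initialization
torch.cuda.synchronize()  # must not happen inside a capture

pool = torch.cuda.graph_pool_handle()  # one pool shared by all graphs
graphs = {}
for num_tokens in reversed([8, 16, 64, 256]):  # largest shape first
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        # Real code would also keep a per-shape static output tensor.
        out = model(static_input[:num_tokens])
    graphs[num_tokens] = g
# Smaller captures reuse blocks the 256-token capture already owns, so
# the total capture footprint stays close to the largest shape alone.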
@@ -105,35 +105,48 @@ class Worker:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()

+        _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         self.model_runner.profile_run()

-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
         torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
+        assert self.init_gpu_memory > free_gpu_memory, (
             "Error in memory profiling. "
             f"Initial free memory {self.init_gpu_memory}, current free memory"
             f" {free_gpu_memory}. This happens when the GPU memory was "
             "not properly cleaned up before initializing the vLLM instance.")

+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+
+        # Check for any memory left around that may have been allocated on the
+        # gpu outside of `torch`. NCCL operations, for example, can use a few
+        # GB during a forward pass
+        torch.cuda.empty_cache()
+        torch_allocated_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        total_allocated_bytes = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+        available_kv_cache_memory = (
+            total_gpu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
         cache_block_size = _get_cache_block_size(self.cache_config,
                                                  self.model_config,
                                                  self.parallel_config)
-        num_gpu_blocks = int(
-            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-             peak_memory) // cache_block_size)
+        num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         # if self.model_runner.lora_manager:
         #     self.model_runner.remove_all_loras()
         gc.collect()
         torch.cuda.empty_cache()
         return num_gpu_blocks, 0

     def initialize_cache(self, num_gpu_blocks: int) -> None:
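The Worker-side change rewrites the block-count arithmetic around an explicit available_kv_cache_memory term, measures the torch peak directly from allocator stats instead of free-memory deltas, and folds non-torch allocations (NCCL buffers, for example) into the peak. The resulting formula as a standalone sketch with illustrative names:

def num_kv_cache_blocks(total_gpu_memory: int,
                        gpu_memory_utilization: float,
                        peak_memory: int,
                        non_torch_allocations: int,
                        cache_block_size: int) -> int:
    # Mirrors the arithmetic in the diff above.
    if non_torch_allocations > 0:
        peak_memory += non_torch_allocations
    available = total_gpu_memory * gpu_memory_utilization - peak_memory
    return max(int(available // cache_block_size), 0)

# Example: 80 GiB GPU at 0.9 utilization, 18 GiB torch peak, 2 GiB of
# non-torch allocations, 2 MiB per KV-cache block:
# (80 * 0.9 - 20) GiB / 2 MiB = 52 GiB / 2 MiB = 26624 blocks.
GiB = 1 << 30
print(num_kv_cache_blocks(80 * GiB, 0.9, 18 * GiB, 2 * GiB, 2 << 20))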