[torch.compile] decouple compile sizes and cudagraph sizes (#12243)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao <youkaichao@gmail.com>
Date: 2025-01-24 02:01:30 +08:00 (committed by GitHub)
Commit: 6e650f56a1
Parent: 3f50c148fd
7 changed files with 94 additions and 57 deletions

View File

@@ -680,7 +680,7 @@ class VllmBackend:
 class ConcreteSizeEntry:
     runtime_shape: int
     need_to_compile: bool  # the size is in compile_sizes
-    use_cudagraph: bool  # the size is in capture_sizes
+    use_cudagraph: bool  # the size is in cudagraph_capture_sizes
 
     compiled: bool = False
     runnable: Callable = None  # type: ignore
@@ -727,8 +727,8 @@ class PiecewiseBackend:
 
         self.compile_sizes: Set[int] = set(
             self.compilation_config.compile_sizes)
-        self.capture_sizes: Set[int] = set(
-            self.compilation_config.capture_sizes
+        self.cudagraph_capture_sizes: Set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()
 
         self.first_run_finished = False
@@ -746,11 +746,11 @@ class PiecewiseBackend:
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
         self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
-        for shape in self.compile_sizes.union(self.capture_sizes):
+        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
-                use_cudagraph=shape in self.capture_sizes,
+                use_cudagraph=shape in self.cudagraph_capture_sizes,
             )
 
     def check_for_ending_compilation(self):
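
Editor's illustration (not part of the commit): because compile sizes and cudagraph capture sizes are now tracked as independent sets, a runtime shape can be inductor-compiled only, cudagraph-captured only, or both. The standalone sketch below mirrors the role of ConcreteSizeEntry; all names and sizes are illustrative.

    from dataclasses import dataclass

    @dataclass
    class SizeEntry:                        # stands in for ConcreteSizeEntry
        runtime_shape: int
        need_to_compile: bool               # shape is in compile_sizes
        use_cudagraph: bool                 # shape is in cudagraph_capture_sizes

    compile_sizes = {8192, 4, 2}            # 8192: compile-only, no cudagraph
    cudagraph_capture_sizes = {8, 4, 2, 1}
    entries = {
        shape: SizeEntry(
            runtime_shape=shape,
            need_to_compile=shape in compile_sizes,
            use_cudagraph=shape in cudagraph_capture_sizes,
        )
        for shape in compile_sizes | cudagraph_capture_sizes
    }
    assert entries[8192].need_to_compile and not entries[8192].use_cudagraph
    assert entries[1].use_cudagraph and not entries[1].need_to_compile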

View File

@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
     - use_inductor: whether to use inductor compilation.
         - False: inductor compilation is not used. graph runs in eager.
         - True: inductor compilation is used. one graph for symbolic shape
-            is compiled. In addition, compile for cudagraph sizes that are
-            in candidate_compile_sizes, using configurations
-            in inductor_compile_config.
-    - candidate_compile_sizes: sizes to compile for inductor.
+            is compiled. In addition, compile for compile_sizes,
+            using configurations in inductor_compile_config.
+    - compile_sizes: sizes to compile for inductor. In addition
+        to integers, it also supports "cudagraph_capture_sizes" to
+        specify the sizes for cudagraph capture.
     - inductor_compile_config: additional configurations for inductor.
         - None: use default configurations.
     - inductor_passes: additional passes for inductor. It is a dictionary
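
For illustration only: per the docstring above, a compilation config may now mix integer sizes with the "cudagraph_capture_sizes" sentinel. The snippet below is a hypothetical configuration payload shown as a plain dict; how it is handed to the engine (CLI flag or constructor argument) is outside the scope of this hunk.

    # Hypothetical compilation-config payload: compile inductor graphs for a
    # size of 8192 tokens (e.g. a chunked-prefill budget) plus every cudagraph
    # capture size, with cudagraphs enabled.
    compilation_config = {
        "use_inductor": True,
        "use_cudagraph": True,
        "compile_sizes": [8192, "cudagraph_capture_sizes"],
    }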
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
     splitting_ops: List[str] = Field(default=None)  # type: ignore
 
     use_inductor: bool = True
-    candidate_compile_sizes: Optional[List[int]] = Field(default=None)
+    compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
     inductor_compile_config: Dict = Field(default_factory=dict)
     inductor_passes: Dict[str, str] = Field(default_factory=dict)
@ -2790,8 +2791,6 @@ class CompilationConfig(BaseModel):
pass_config: PassConfig = Field(default_factory=PassConfig) pass_config: PassConfig = Field(default_factory=PassConfig)
# not configurable, computed after init # not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
max_capture_size: int = PrivateAttr max_capture_size: int = PrivateAttr
local_cache_dir: str = PrivateAttr # local cache dir for each rank local_cache_dir: str = PrivateAttr # local cache dir for each rank
# optimization: # optimization:
@@ -2918,43 +2917,47 @@ class CompilationConfig(BaseModel):
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)
 
-    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
+    def init_with_cudagraph_sizes(self,
+                                  cudagraph_capture_sizes: List[int]) -> None:
         """To complete the initialization of config,
         we need to know the cudagraph sizes."""
 
         if self.cudagraph_capture_sizes is None:
-            self.capture_sizes = sizes_to_specialize
+            self.cudagraph_capture_sizes = cudagraph_capture_sizes
         else:
-            self.capture_sizes = self.cudagraph_capture_sizes
+            # de-duplicate the sizes provided by the config
+            self.cudagraph_capture_sizes = list(
+                set(self.cudagraph_capture_sizes))
             logger.info(("cudagraph sizes specified by model runner"
                          " %s is overridden by config %s"),
-                        sizes_to_specialize, self.cudagraph_capture_sizes)
+                        cudagraph_capture_sizes, self.cudagraph_capture_sizes)
 
-        if self.candidate_compile_sizes is None:
-            self.candidate_compile_sizes = []
-        self.compile_sizes = [
-            x for x in self.candidate_compile_sizes if x in self.capture_sizes
-        ]
-        ignored_sizes = [
-            x for x in self.candidate_compile_sizes
-            if x not in self.capture_sizes
-        ]
-        if ignored_sizes:
-            logger.warning(("candidate_compile_sizes %s are ignored "
-                            "because they are not cudagraph capture sizes."),
-                           ignored_sizes)
+        computed_compile_sizes = []
+        if self.compile_sizes is not None:
+            # de-duplicate the sizes provided by the config
+            self.compile_sizes = list(set(self.compile_sizes))
+            for x in self.compile_sizes:
+                if isinstance(x, str):
+                    assert x == "cudagraph_capture_sizes", \
+                        "Unrecognized size type in compile_sizes, " \
+                        f"expect 'cudagraph_capture_sizes', got {x}"
+                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
+                else:
+                    assert isinstance(x, int)
+                    computed_compile_sizes.append(x)
+        self.compile_sizes = computed_compile_sizes  # type: ignore
 
         # sort to make sure cudagraph capture sizes are in descending order
-        self.capture_sizes.sort(reverse=True)
-        self.max_capture_size = self.capture_sizes[
-            0] if self.capture_sizes else 0
+        self.cudagraph_capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.cudagraph_capture_sizes[
+            0] if self.cudagraph_capture_sizes else 0
 
         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [
             0 for i in range(self.max_capture_size + 1)
         ]
-        for end, start in zip(self.capture_sizes,
-                              self.capture_sizes[1:] + [0]):
+        for end, start in zip(self.cudagraph_capture_sizes,
+                              self.cudagraph_capture_sizes[1:] + [0]):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
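
A worked example (illustrative, not from the commit) of how the new logic above resolves compile_sizes entries: integers pass through, and the literal string "cudagraph_capture_sizes" expands to the capture list. The function name and values below are made up for illustration.

    from typing import List, Union

    def resolve_compile_sizes(compile_sizes: List[Union[int, str]],
                              cudagraph_capture_sizes: List[int]) -> List[int]:
        resolved: List[int] = []
        for x in set(compile_sizes):      # de-duplicated, as in the config code
            if isinstance(x, str):
                assert x == "cudagraph_capture_sizes", f"unrecognized entry: {x}"
                resolved.extend(cudagraph_capture_sizes)
            else:
                assert isinstance(x, int)
                resolved.append(x)
        return resolved

    # e.g. [8192, "cudagraph_capture_sizes"] with capture sizes [1, 2, 4, 8]
    # resolves to 8192 plus 1, 2, 4 and 8 (set iteration order may vary).
    assert sorted(resolve_compile_sizes([8192, "cudagraph_capture_sizes"],
                                        [1, 2, 4, 8])) == [1, 2, 4, 8, 8192]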
@@ -3225,14 +3228,14 @@ class VllmConfig:
         However, if users specify the cudagraph capture sizes through
         compilation config, we will use the specified sizes instead.
 
-        In the end, `vllm_config.compilation_config.capture_sizes` will be the
-        final sizes to capture cudagraph (in descending order).
+        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
+        will be the final sizes to capture cudagraph (in descending order).
 
         During runtime, if batchsize is larger than
-        `vllm_config.compilation_config.capture_sizes`,
+        `vllm_config.compilation_config.cudagraph_capture_sizes`,
         no cudagraph will be used.
         If the batch size is no larger than
-        `vllm_config.compilation_config.capture_sizes`,
+        `vllm_config.compilation_config.cudagraph_capture_sizes`,
         we can quickly find the padded graph size for a given batch size by
         looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
         """

View File

@@ -120,7 +120,8 @@ class Metrics:
             labelnames=labelnames)
         buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
         if not vllm_config.model_config.enforce_eager:
-            buckets = vllm_config.compilation_config.capture_sizes.copy()
+            buckets = vllm_config.compilation_config.\
+                cudagraph_capture_sizes.copy()
             buckets.sort()
         self.histogram_iteration_tokens = self._histogram_cls(
             name="vllm:iteration_tokens_total",

View File

@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 import torch
@@ -128,7 +128,8 @@ class GPUModelRunner:
         # self.cudagraph_batch_sizes sorts in ascending order.
         # The batch sizes in the config are in descending order.
         self.cudagraph_batch_sizes = list(
-            reversed(self.vllm_config.compilation_config.capture_sizes))
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
         # Cache the device properties.
         self.device_properties = torch.cuda.get_device_properties(self.device)
@@ -834,10 +835,12 @@ class GPUModelRunner:
     @torch.inference_mode()
     def _dummy_run(
         self,
-        model: nn.Module,
         num_tokens: int,
-        kv_caches: List[torch.Tensor],
+        kv_caches: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
+        model = self.model
+        if kv_caches is None:
+            kv_caches = self.kv_caches
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -963,8 +966,7 @@ class GPUModelRunner:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                        dummy_kv_caches)
+        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
@@ -990,8 +992,8 @@ class GPUModelRunner:
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(self.model, num_tokens, self.kv_caches)
-                self._dummy_run(self.model, num_tokens, self.kv_caches)
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)
 
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]

View File

@@ -206,6 +206,18 @@ class Worker:
         self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
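
Illustration with made-up values (not from the commit): the warm-up selection added above compiles and exercises at startup any size that appears only in compile_sizes (for example a chunked-prefill token budget), while sizes that are also cudagraph capture sizes are left to capture_model().

    compile_sizes = [8192, 4, 2]             # 8192 = hypothetical chunked-prefill budget
    cudagraph_capture_sizes = [8, 4, 2, 1]
    warmup_sizes = [x for x in compile_sizes
                    if x not in cudagraph_capture_sizes]
    assert sorted(warmup_sizes, reverse=True) == [8192]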

View File

@@ -1256,13 +1256,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     @torch.inference_mode()
     def profile_run(self) -> None:
+        max_num_batched_tokens = \
+            self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+        self._dummy_run(max_num_batched_tokens, max_num_seqs)
+
+    def _dummy_run(self,
+                   max_num_batched_tokens: int,
+                   max_num_seqs: int = 1) -> None:
         with self.set_in_profile_run():
             # Enable top-k sampling to reflect the accurate memory usage.
             sampling_params = \
                 SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-            max_num_batched_tokens = \
-                self.scheduler_config.max_num_batched_tokens
-            max_num_seqs = self.scheduler_config.max_num_seqs
             # This represents the maximum number of different requests
             # that will have unique loras, an therefore the max amount of memory
             # consumption create dummy lora request copies from the lora request
@@ -1491,13 +1497,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
                 # Only rank 0 should print progress bar during capture
-                capture_sizes = (
-                    tqdm(
-                        self.vllm_config.compilation_config.capture_sizes,
-                        desc="Capturing CUDA graph shapes",
-                    ) if get_tensor_model_parallel_rank() == 0 else
-                    self.vllm_config.compilation_config.capture_sizes)
-                for batch_size in capture_sizes:
+                cudagraph_capture_sizes = (tqdm(
+                    self.vllm_config.compilation_config.
+                    cudagraph_capture_sizes,
+                    desc="Capturing CUDA graph shapes",
+                ) if get_tensor_model_parallel_rank() == 0 else
+                    self.vllm_config.compilation_config.
+                    cudagraph_capture_sizes)
+                for batch_size in cudagraph_capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,

View File

@@ -323,6 +323,18 @@ class Worker(LocalOrDistributedWorkerBase):
                       self.gpu_cache)
 
     def _warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model(self.gpu_cache)
         # Reset the seed to ensure that the random state is not affected by