[torch.compile] decouple compile sizes and cudagraph sizes (#12243)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-01-24 02:01:30 +08:00 committed by GitHub
parent 3f50c148fd
commit 6e650f56a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 94 additions and 57 deletions

View File

@ -680,7 +680,7 @@ class VllmBackend:
class ConcreteSizeEntry:
runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_cudagraph: bool # the size is in capture_sizes
use_cudagraph: bool # the size is in cudagraph_capture_sizes
compiled: bool = False
runnable: Callable = None # type: ignore
@ -727,8 +727,8 @@ class PiecewiseBackend:
self.compile_sizes: Set[int] = set(
self.compilation_config.compile_sizes)
self.capture_sizes: Set[int] = set(
self.compilation_config.capture_sizes
self.cudagraph_capture_sizes: Set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()
self.first_run_finished = False
@ -746,11 +746,11 @@ class PiecewiseBackend:
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.capture_sizes):
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.capture_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
)
def check_for_ending_compilation(self):

View File

@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for cudagraph sizes that are
in candidate_compile_sizes, using configurations
in inductor_compile_config.
- candidate_compile_sizes: sizes to compile for inductor.
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config.
- compile_sizes: sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
splitting_ops: List[str] = Field(default=None) # type: ignore
use_inductor: bool = True
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
@ -2790,8 +2791,6 @@ class CompilationConfig(BaseModel):
pass_config: PassConfig = Field(default_factory=PassConfig)
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
max_capture_size: int = PrivateAttr
local_cache_dir: str = PrivateAttr # local cache dir for each rank
# optimization:
@ -2918,43 +2917,47 @@ class CompilationConfig(BaseModel):
from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
def init_with_cudagraph_sizes(self,
cudagraph_capture_sizes: List[int]) -> None:
"""To complete the initialization of config,
we need to know the cudagraph sizes."""
if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
self.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
self.capture_sizes = self.cudagraph_capture_sizes
# de-duplicate the sizes provided by the config
self.cudagraph_capture_sizes = list(
set(self.cudagraph_capture_sizes))
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
cudagraph_capture_sizes, self.cudagraph_capture_sizes)
if self.candidate_compile_sizes is None:
self.candidate_compile_sizes = []
self.compile_sizes = [
x for x in self.candidate_compile_sizes if x in self.capture_sizes
]
ignored_sizes = [
x for x in self.candidate_compile_sizes
if x not in self.capture_sizes
]
if ignored_sizes:
logger.warning(("candidate_compile_sizes %s are ignored "
"because they are not cudagraph capture sizes."),
ignored_sizes)
computed_compile_sizes = []
if self.compile_sizes is not None:
# de-duplicate the sizes provided by the config
self.compile_sizes = list(set(self.compile_sizes))
for x in self.compile_sizes:
if isinstance(x, str):
assert x == "cudagraph_capture_sizes", \
"Unrecognized size type in compile_sizes, " \
f"expect 'cudagraph_capture_sizes', got {x}"
computed_compile_sizes.extend(self.cudagraph_capture_sizes)
else:
assert isinstance(x, int)
computed_compile_sizes.append(x)
self.compile_sizes = computed_compile_sizes # type: ignore
# sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True)
self.max_capture_size = self.capture_sizes[
0] if self.capture_sizes else 0
self.cudagraph_capture_sizes.sort(reverse=True)
self.max_capture_size = self.cudagraph_capture_sizes[
0] if self.cudagraph_capture_sizes else 0
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_capture_size + 1)
]
for end, start in zip(self.capture_sizes,
self.capture_sizes[1:] + [0]):
for end, start in zip(self.cudagraph_capture_sizes,
self.cudagraph_capture_sizes[1:] + [0]):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
@ -3225,14 +3228,14 @@ class VllmConfig:
However, if users specify the cudagraph capture sizes through
compilation config, we will use the specified sizes instead.
In the end, `vllm_config.compilation_config.capture_sizes` will be the
final sizes to capture cudagraph (in descending order).
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
will be the final sizes to capture cudagraph (in descending order).
During runtime, if batchsize is larger than
`vllm_config.compilation_config.capture_sizes`,
`vllm_config.compilation_config.cudagraph_capture_sizes`,
no cudagraph will be used.
If the batch size is no larger than
`vllm_config.compilation_config.capture_sizes`,
`vllm_config.compilation_config.cudagraph_capture_sizes`,
we can quickly find the padded graph size for a given batch size by
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
"""

View File

@ -120,7 +120,8 @@ class Metrics:
labelnames=labelnames)
buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.capture_sizes.copy()
buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy()
buckets.sort()
self.histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total",

View File

@ -1,6 +1,6 @@
import gc
import time
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
import numpy as np
import torch
@ -128,7 +128,8 @@ class GPUModelRunner:
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list(
reversed(self.vllm_config.compilation_config.capture_sizes))
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# Cache the device properties.
self.device_properties = torch.cuda.get_device_properties(self.device)
@ -834,10 +835,12 @@ class GPUModelRunner:
@torch.inference_mode()
def _dummy_run(
self,
model: nn.Module,
num_tokens: int,
kv_caches: List[torch.Tensor],
kv_caches: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
model = self.model
if kv_caches is None:
kv_caches = self.kv_caches
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
@ -963,8 +966,7 @@ class GPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
dummy_kv_caches)
hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
@ -990,8 +992,8 @@ class GPUModelRunner:
for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(self.model, num_tokens, self.kv_caches)
self._dummy_run(self.model, num_tokens, self.kv_caches)
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]

View File

@ -206,6 +206,18 @@ class Worker:
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by

View File

@ -1256,13 +1256,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@torch.inference_mode()
def profile_run(self) -> None:
max_num_batched_tokens = \
self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
self._dummy_run(max_num_batched_tokens, max_num_seqs)
def _dummy_run(self,
max_num_batched_tokens: int,
max_num_seqs: int = 1) -> None:
with self.set_in_profile_run():
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = \
SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = \
self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
@ -1491,13 +1497,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
# Only rank 0 should print progress bar during capture
capture_sizes = (
tqdm(
self.vllm_config.compilation_config.capture_sizes,
desc="Capturing CUDA graph shapes",
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.capture_sizes)
for batch_size in capture_sizes:
cudagraph_capture_sizes = (tqdm(
self.vllm_config.compilation_config.
cudagraph_capture_sizes,
desc="Capturing CUDA graph shapes",
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.
cudagraph_capture_sizes)
for batch_size in cudagraph_capture_sizes:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,

View File

@ -323,6 +323,18 @@ class Worker(LocalOrDistributedWorkerBase):
self.gpu_cache)
def _warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by