[torch.compile] decouple compile sizes and cudagraph sizes (#12243)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 6e650f56a1
parent 3f50c148fd
@@ -680,7 +680,7 @@ class VllmBackend:
 class ConcreteSizeEntry:
     runtime_shape: int
     need_to_compile: bool  # the size is in compile_sizes
-    use_cudagraph: bool  # the size is in capture_sizes
+    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

     compiled: bool = False
     runnable: Callable = None  # type: ignore
@@ -727,8 +727,8 @@ class PiecewiseBackend:

         self.compile_sizes: Set[int] = set(
             self.compilation_config.compile_sizes)
-        self.capture_sizes: Set[int] = set(
-            self.compilation_config.capture_sizes
+        self.cudagraph_capture_sizes: Set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()

         self.first_run_finished = False
@@ -746,11 +746,11 @@ class PiecewiseBackend:
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
         self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
-        for shape in self.compile_sizes.union(self.capture_sizes):
+        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
-                use_cudagraph=shape in self.capture_sizes,
+                use_cudagraph=shape in self.cudagraph_capture_sizes,
             )

     def check_for_ending_compilation(self):
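A consequence of the rename is that a ConcreteSizeEntry can now mark a shape for compilation without marking it for cudagraph capture (or the other way around). A minimal sketch, assuming ConcreteSizeEntry sits next to VllmBackend in vllm.compilation.backends as the hunks above suggest, and using made-up sizes:

    from vllm.compilation.backends import ConcreteSizeEntry

    # Hypothetical entry: 8192 is in compile_sizes (so inductor compiles a
    # kernel for it) but not in cudagraph_capture_sizes (so no graph capture).
    entry = ConcreteSizeEntry(
        runtime_shape=8192,
        need_to_compile=True,
        use_cudagraph=False,
    )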
@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
     - use_inductor: whether to use inductor compilation.
         - False: inductor compilation is not used. graph runs in eager.
         - True: inductor compilation is used. one graph for symbolic shape
-            is compiled. In addition, compile for cudagraph sizes that are
-            in candidate_compile_sizes, using configurations
-            in inductor_compile_config.
-    - candidate_compile_sizes: sizes to compile for inductor.
+            is compiled. In addition, compile for compile_sizes,
+            using configurations in inductor_compile_config.
+    - compile_sizes: sizes to compile for inductor. In addition
+        to integers, it also supports "cudagraph_capture_sizes" to
+        specify the sizes for cudagraph capture.
     - inductor_compile_config: additional configurations for inductor.
         - None: use default configurations.
     - inductor_passes: additional passes for inductor. It is a dictionary
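For readers who want to see the new knob in use, here is a hedged sketch of a configuration that exercises it. The concrete sizes and the choice to build CompilationConfig directly in Python (rather than passing the equivalent JSON on the command line) are illustrative assumptions, not part of this commit.

    from vllm.config import CompilationConfig

    # Hypothetical setup: capture cudagraphs only for small decode batches,
    # but also compile a dedicated inductor kernel for 8192 (for example, a
    # chunked-prefill max-num-batched-tokens) that is never captured.
    compilation_config = CompilationConfig(
        use_inductor=True,
        use_cudagraph=True,
        cudagraph_capture_sizes=[1, 2, 4, 8],
        compile_sizes=[8192, "cudagraph_capture_sizes"],
    )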
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
     splitting_ops: List[str] = Field(default=None)  # type: ignore

     use_inductor: bool = True
-    candidate_compile_sizes: Optional[List[int]] = Field(default=None)
+    compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
     inductor_compile_config: Dict = Field(default_factory=dict)
     inductor_passes: Dict[str, str] = Field(default_factory=dict)

@@ -2790,8 +2791,6 @@ class CompilationConfig(BaseModel):
     pass_config: PassConfig = Field(default_factory=PassConfig)

     # not configurable, computed after init
-    compile_sizes: List[int] = PrivateAttr
-    capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
     local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
@@ -2918,43 +2917,47 @@ class CompilationConfig(BaseModel):
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)

-    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
+    def init_with_cudagraph_sizes(self,
+                                  cudagraph_capture_sizes: List[int]) -> None:
         """To complete the initialization of config,
         we need to know the cudagraph sizes."""

         if self.cudagraph_capture_sizes is None:
-            self.capture_sizes = sizes_to_specialize
+            self.cudagraph_capture_sizes = cudagraph_capture_sizes
         else:
-            self.capture_sizes = self.cudagraph_capture_sizes
+            # de-duplicate the sizes provided by the config
+            self.cudagraph_capture_sizes = list(
+                set(self.cudagraph_capture_sizes))
             logger.info(("cudagraph sizes specified by model runner"
                          " %s is overridden by config %s"),
-                        sizes_to_specialize, self.cudagraph_capture_sizes)
+                        cudagraph_capture_sizes, self.cudagraph_capture_sizes)

-        if self.candidate_compile_sizes is None:
-            self.candidate_compile_sizes = []
-        self.compile_sizes = [
-            x for x in self.candidate_compile_sizes if x in self.capture_sizes
-        ]
-        ignored_sizes = [
-            x for x in self.candidate_compile_sizes
-            if x not in self.capture_sizes
-        ]
-        if ignored_sizes:
-            logger.warning(("candidate_compile_sizes %s are ignored "
-                            "because they are not cudagraph capture sizes."),
-                           ignored_sizes)
+        computed_compile_sizes = []
+        if self.compile_sizes is not None:
+            # de-duplicate the sizes provided by the config
+            self.compile_sizes = list(set(self.compile_sizes))
+            for x in self.compile_sizes:
+                if isinstance(x, str):
+                    assert x == "cudagraph_capture_sizes", \
+                        "Unrecognized size type in compile_sizes, " \
+                        f"expect 'cudagraph_capture_sizes', got {x}"
+                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
+                else:
+                    assert isinstance(x, int)
+                    computed_compile_sizes.append(x)
+        self.compile_sizes = computed_compile_sizes  # type: ignore

         # sort to make sure cudagraph capture sizes are in descending order
-        self.capture_sizes.sort(reverse=True)
-        self.max_capture_size = self.capture_sizes[
-            0] if self.capture_sizes else 0
+        self.cudagraph_capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.cudagraph_capture_sizes[
+            0] if self.cudagraph_capture_sizes else 0

         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [
             0 for i in range(self.max_capture_size + 1)
         ]
-        for end, start in zip(self.capture_sizes,
-                              self.capture_sizes[1:] + [0]):
+        for end, start in zip(self.cudagraph_capture_sizes,
+                              self.cudagraph_capture_sizes[1:] + [0]):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
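The expansion loop above is the heart of the new compile_sizes semantics. The following standalone sketch (assumed inputs, not code from this commit) replays the same logic outside the config class to show how the special string entry resolves against the capture sizes.

    from typing import List, Union

    def expand_compile_sizes(compile_sizes: List[Union[int, str]],
                             cudagraph_capture_sizes: List[int]) -> List[int]:
        # Same rule as init_with_cudagraph_sizes: integers pass through,
        # the string "cudagraph_capture_sizes" expands to all capture sizes.
        computed: List[int] = []
        for x in set(compile_sizes):  # de-duplicate, as the config does
            if isinstance(x, str):
                assert x == "cudagraph_capture_sizes", \
                    f"expect 'cudagraph_capture_sizes', got {x}"
                computed.extend(cudagraph_capture_sizes)
            else:
                assert isinstance(x, int)
                computed.append(x)
        return computed

    # Assumed example: [8192, "cudagraph_capture_sizes"] with capture sizes
    # [8, 4, 2, 1] yields 8192 plus every capture size (order may vary
    # because of the de-duplicating set()).
    print(expand_compile_sizes([8192, "cudagraph_capture_sizes"], [8, 4, 2, 1]))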
@@ -3225,14 +3228,14 @@ class VllmConfig:
     However, if users specify the cudagraph capture sizes through
     compilation config, we will use the specified sizes instead.

-    In the end, `vllm_config.compilation_config.capture_sizes` will be the
-    final sizes to capture cudagraph (in descending order).
+    In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
+    will be the final sizes to capture cudagraph (in descending order).

     During runtime, if batchsize is larger than
-    `vllm_config.compilation_config.capture_sizes`,
+    `vllm_config.compilation_config.cudagraph_capture_sizes`,
     no cudagraph will be used.
     If the batch size is no larger than
-    `vllm_config.compilation_config.capture_sizes`,
+    `vllm_config.compilation_config.cudagraph_capture_sizes`,
     we can quickly find the padded graph size for a given batch size by
     looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
     """
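The docstring above leans on the bs_to_padded_graph_size table that init_with_cudagraph_sizes pre-computes. Below is a hedged, standalone sketch of that table and the O(1) lookup it enables; the capture sizes are assumed, and the else branch plus the final assignment complete the loop shown in the diff under the assumption that intermediate batch sizes pad up to the next larger captured size, consistent with the docstring.

    # Assumed capture sizes, already sorted in descending order.
    cudagraph_capture_sizes = [8, 4, 2, 1]
    max_capture_size = cudagraph_capture_sizes[0]

    bs_to_padded_graph_size = [0] * (max_capture_size + 1)
    for end, start in zip(cudagraph_capture_sizes,
                          cudagraph_capture_sizes[1:] + [0]):
        for bs in range(start, end):
            if bs == start:
                bs_to_padded_graph_size[bs] = start
            else:
                # pad intermediate batch sizes up to the next captured size
                bs_to_padded_graph_size[bs] = end
    bs_to_padded_graph_size[max_capture_size] = max_capture_size

    # Lookup is a plain list index: a batch of 3 pads up to the captured
    # size 4; a batch of 8 runs at exactly 8.
    assert bs_to_padded_graph_size[3] == 4
    assert bs_to_padded_graph_size[8] == 8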
@@ -120,7 +120,8 @@ class Metrics:
             labelnames=labelnames)
         buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
         if not vllm_config.model_config.enforce_eager:
-            buckets = vllm_config.compilation_config.capture_sizes.copy()
+            buckets = vllm_config.compilation_config.\
+                cudagraph_capture_sizes.copy()
             buckets.sort()
         self.histogram_iteration_tokens = self._histogram_cls(
             name="vllm:iteration_tokens_total",
@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast

 import numpy as np
 import torch
@@ -128,7 +128,8 @@ class GPUModelRunner:
         # self.cudagraph_batch_sizes sorts in ascending order.
         # The batch sizes in the config are in descending order.
         self.cudagraph_batch_sizes = list(
-            reversed(self.vllm_config.compilation_config.capture_sizes))
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))

         # Cache the device properties.
         self.device_properties = torch.cuda.get_device_properties(self.device)
@@ -834,10 +835,12 @@ class GPUModelRunner:
     @torch.inference_mode()
     def _dummy_run(
         self,
-        model: nn.Module,
         num_tokens: int,
-        kv_caches: List[torch.Tensor],
+        kv_caches: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
+        model = self.model
+        if kv_caches is None:
+            kv_caches = self.kv_caches
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -963,8 +966,7 @@ class GPUModelRunner:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                        dummy_kv_caches)
+        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
@@ -990,8 +992,8 @@ class GPUModelRunner:
         for num_tokens in reversed(self.cudagraph_batch_sizes):
             for _ in range(self.vllm_config.compilation_config.
                            cudagraph_num_of_warmups):
-                self._dummy_run(self.model, num_tokens, self.kv_caches)
-            self._dummy_run(self.model, num_tokens, self.kv_caches)
+                self._dummy_run(num_tokens)
+            self._dummy_run(num_tokens)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -206,6 +206,18 @@ class Worker:
         self.model_runner.initialize_kv_cache(kv_cache_config)

     def compile_or_warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
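This warm-up block is where the decoupling pays off: any size listed in compile_sizes but absent from cudagraph_capture_sizes is run once eagerly so its inductor kernel is ready before serving starts, while cudagraph capture remains limited to the capture sizes. A tiny sketch with assumed values:

    # Assumed values; the real lists come from the compilation config after
    # init_with_cudagraph_sizes has resolved compile_sizes.
    compile_sizes = [8192, 8, 4, 2, 1]      # e.g. chunked-prefill max size + capture sizes
    cudagraph_capture_sizes = [8, 4, 2, 1]

    warmup_sizes = [x for x in compile_sizes
                    if x not in cudagraph_capture_sizes]
    assert warmup_sizes == [8192]  # only the extra size needs its own dummy run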
@@ -1256,13 +1256,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

     @torch.inference_mode()
     def profile_run(self) -> None:
+        max_num_batched_tokens = \
+            self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+        self._dummy_run(max_num_batched_tokens, max_num_seqs)
+
+    def _dummy_run(self,
+                   max_num_batched_tokens: int,
+                   max_num_seqs: int = 1) -> None:
         with self.set_in_profile_run():
             # Enable top-k sampling to reflect the accurate memory usage.
             sampling_params = \
                 SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-            max_num_batched_tokens = \
-                self.scheduler_config.max_num_batched_tokens
-            max_num_seqs = self.scheduler_config.max_num_seqs
             # This represents the maximum number of different requests
             # that will have unique loras, an therefore the max amount of memory
             # consumption create dummy lora request copies from the lora request
@@ -1491,13 +1497,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         for virtual_engine in range(
                 self.parallel_config.pipeline_parallel_size):
             # Only rank 0 should print progress bar during capture
-            capture_sizes = (
-                tqdm(
-                    self.vllm_config.compilation_config.capture_sizes,
-                    desc="Capturing CUDA graph shapes",
-                ) if get_tensor_model_parallel_rank() == 0 else
-                self.vllm_config.compilation_config.capture_sizes)
-            for batch_size in capture_sizes:
+            cudagraph_capture_sizes = (tqdm(
+                self.vllm_config.compilation_config.
+                cudagraph_capture_sizes,
+                desc="Capturing CUDA graph shapes",
+            ) if get_tensor_model_parallel_rank() == 0 else
+                self.vllm_config.compilation_config.
+                cudagraph_capture_sizes)
+            for batch_size in cudagraph_capture_sizes:
                 attn_metadata = (
                     self.attn_state.graph_capture_get_metadata_for_batch(
                         batch_size,
@@ -323,6 +323,18 @@ class Worker(LocalOrDistributedWorkerBase):
                       self.gpu_cache)

     def _warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model(self.gpu_cache)
         # Reset the seed to ensure that the random state is not affected by