[core] clean up cudagraph batchsize padding logic (#10996)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2024-12-12 22:57:50 -08:00 committed by GitHub
parent 34f1a806d5
commit be39e3cd18
11 changed files with 150 additions and 104 deletions

View File

@@ -1,7 +1,7 @@
 import pytest
 from tests.utils import multi_gpu_test
-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])
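
Not part of the diff: a minimal sketch of how the replacement API used in the test above is expected to be called. The model name is illustrative only, and the exact padded value depends on the engine's capture sizes.

from vllm.engine.arg_utils import EngineArgs

# Build a VllmConfig the same way the updated test does.
vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()

# pad_for_cudagraph() rounds a runtime batch size up to the nearest
# captured cudagraph batch size, e.g. 3 -> 4 with the default capture list.
padded = vllm_config.pad_for_cudagraph(3)
assert padded >= 3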

View File

@@ -5,7 +5,7 @@ Run `pytest tests/models/test_mamba.py`.
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])

View File

@@ -4,7 +4,6 @@ from typing import List
 import pytest
 import torch
-from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
+        expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(

View File

@@ -3,7 +3,6 @@ from typing import List
 import pytest
 import torch
-from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
@@ -177,7 +176,8 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)
-    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = model_runner.vllm_config.pad_for_cudagraph(
+        len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0

View File

@@ -2354,6 +2354,12 @@ class CompilationConfig(BaseModel):
     # not configurable, computed after init
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
+    max_capture_size: int = PrivateAttr
+    # optimization:
+    # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
+    # since we know all keys are in a range [0, max_capture_size],
+    # we can optimize it to List[int] for better lookup performance.
+    bs_to_padded_graph_size: List[int] = PrivateAttr
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
@@ -2365,6 +2371,19 @@ class CompilationConfig(BaseModel):
     # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr

+    def __repr__(self) -> str:
+        exclude = {
+            "static_forward_context",
+            "enabled_custom_ops",
+            "disabled_custom_ops",
+            "compilation_time",
+            "bs_to_padded_graph_size",
+            "pass_config",
+        }
+        return self.model_dump_json(exclude=exclude, exclude_unset=True)
+
+    __str__ = __repr__
+
     @classmethod
     def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
@@ -2450,18 +2469,22 @@
         # sort to make sure cudagraph capture sizes are in descending order
         self.capture_sizes.sort(reverse=True)
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
+        self.max_capture_size = self.capture_sizes[
+            0] if self.capture_sizes else 0
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_capture_size + 1)
+        ]
+        for end, start in zip(self.capture_sizes,
+                              self.capture_sizes[1:] + [0]):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
+        self.bs_to_padded_graph_size[
+            self.max_capture_size] = self.max_capture_size

 @dataclass
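
Not part of the diff: a standalone sketch of the lookup-table construction added above, using a small hypothetical capture list to show the resulting padding table.

# Hypothetical capture sizes, already sorted in descending order as the
# config does in init_with_cudagraph_sizes().
capture_sizes = [8, 4, 2, 1]
max_capture_size = capture_sizes[0] if capture_sizes else 0

# bs_to_padded_graph_size[bs] == smallest capture size that is >= bs
bs_to_padded_graph_size = [0] * (max_capture_size + 1)
for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
    for bs in range(start, end):
        # a batch size equal to a capture size maps to itself; anything
        # strictly between two capture sizes is padded up to the larger one
        bs_to_padded_graph_size[bs] = start if bs == start else end
bs_to_padded_graph_size[max_capture_size] = max_capture_size

assert bs_to_padded_graph_size == [0, 1, 2, 4, 4, 8, 8, 8, 8]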
@@ -2491,40 +2514,12 @@ class VllmConfig:
                                   init=True)  # type: ignore
     instance_id: str = ""

-    @staticmethod
-    def get_graph_batch_size(batch_size: int) -> int:
-        """Returns the padded batch size given actual batch size.
-
-        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
-        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
-        """
-        if batch_size <= 2:
-            return batch_size
-        elif batch_size <= 4:
-            return 4
-        else:
-            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
-                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-    @staticmethod
-    def get_max_graph_batch_size(max_num_seqs: int) -> int:
-        """
-        max_num_seqs: Maximum number of sequences in a batch.
-        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-        pad the max_num_seqs if necessary by calling get_graph_batch_size,
-        which will deal with some edge cases like 1, 2, 4.
-        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
-        size. if not, it means the padded size is larger than the largest size
-        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
-        _BATCH_SIZES_TO_CAPTURE.
-        """
-        padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
-        if padded_size in _BATCH_SIZES_TO_CAPTURE:
-            return padded_size
-        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
-        return _BATCH_SIZES_TO_CAPTURE[-1]
+    def pad_for_cudagraph(self, batch_size: int) -> int:
+        # if batch_size > self.compilation_config.max_capture_size,
+        # it should raise an IndexError.
+        # the caller should make sure the batch_size is within the range,
+        # i.e., batch_size <= self.compilation_config.max_capture_size
+        return self.compilation_config.bs_to_padded_graph_size[batch_size]

     @staticmethod
     def _get_quantization_config(
@@ -2618,27 +2613,7 @@
             self.compilation_config.pass_config.enable_reshape = False
             self.compilation_config.level = CompilationLevel.PIECEWISE

-        if not envs.VLLM_USE_V1:
-            max_batchsize_to_capture = 0
-            if self.scheduler_config is not None and \
-                self.model_config is not None and \
-                    not self.model_config.enforce_eager:
-                max_batchsize_to_capture = \
-                    self.get_max_graph_batch_size(
-                        self.scheduler_config.max_num_seqs)
-            batch_size_capture_list = [
-                size for size in _BATCH_SIZES_TO_CAPTURE
-                if size <= max_batchsize_to_capture
-            ]
-        else:
-            batch_size_capture_list = []
-            if self.model_config is not None and \
-                not self.model_config.enforce_eager:
-                batch_size_capture_list = [1, 2, 4
-                                           ] + [i for i in range(8, 513, 8)]
-        self.compilation_config.init_with_cudagraph_sizes(
-            batch_size_capture_list)
+        self._set_cudagraph_sizes()

         if self.cache_config is not None and \
                 self.cache_config.cpu_offload_gb > 0 and \
@@ -2659,6 +2634,70 @@
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

+    def _set_cudagraph_sizes(self):
+        """
+        cudagraph batchsize padding logic:
+
+        `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
+        batch sizes that cudagraph will capture.
+
+        Depending on the engine's configuration of `max_num_seqs`, the
+        candidate batch sizes to capture cudagraph will shrink to the subset
+        which just cover the range of `[1, max_num_seqs]`. In the common case,
+        `max_num_seqs` is 256, and the cudagraph batch sizes will be
+        `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
+
+        However, if users specify the cudagraph capture sizes through
+        compilation config, we will use the specified sizes instead.
+
+        In the end, `vllm_config.compilation_config.capture_sizes` will be the
+        final sizes to capture cudagraph (in descending order).
+
+        During runtime, if batchsize is larger than
+        `vllm_config.compilation_config.capture_sizes`,
+        no cudagraph will be used.
+        If the batch size is no larger than
+        `vllm_config.compilation_config.capture_sizes`,
+        we can quickly find the padded graph size for a given batch size by
+        looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
+        """
+
+        # calculate the default `batch_size_capture_list`
+        if not envs.VLLM_USE_V1:
+            batch_size_capture_list = []
+            max_batchsize_to_capture = 0
+            if self.scheduler_config is not None and \
+                self.model_config is not None and \
+                    not self.model_config.enforce_eager:
+                possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
+                # find the minimum size that is larger than max_num_seqs,
+                # which then becomes the max_batchsize_to_capture
+                larger_sizes = [
+                    x for x in possible_sizes
+                    if x >= self.scheduler_config.max_num_seqs
+                ]
+                if larger_sizes:
+                    max_batchsize_to_capture = larger_sizes[0]
+                else:
+                    max_batchsize_to_capture = possible_sizes[-1]
+                # filter out the sizes that are
+                # larger than max_batchsize_to_capture
+                batch_size_capture_list = [
+                    size for size in possible_sizes
+                    if size <= max_batchsize_to_capture
+                ]
+        else:
+            batch_size_capture_list = []
+            if self.model_config is not None and \
+                not self.model_config.enforce_eager:
+                batch_size_capture_list = [1, 2, 4
+                                           ] + [i for i in range(8, 513, 8)]
+
+        self.compilation_config.init_with_cudagraph_sizes(
+            batch_size_capture_list)
+
     def __str__(self):
         return (
             f"model={self.model_config.model!r},"

View File

@@ -7,7 +7,7 @@ from transformers import JambaConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -420,6 +420,17 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+        if self.scheduler_config is not None and \
+                not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                    vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        else:
+            self.max_batch_size = 8192 + 2

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
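
Not part of the diff: a small sketch of why the clamp above is needed. `pad_for_cudagraph()` indexes a table of length `max_capture_size + 1`, so a `max_num_seqs` beyond that range must fall back to `max_capture_size` instead of being looked up. The numbers below are hypothetical.

max_capture_size = 256
# placeholder identity table just to illustrate the length constraint;
# the real table maps each batch size to its padded cudagraph size
bs_to_padded_graph_size = list(range(max_capture_size + 1))

def pad_for_cudagraph(batch_size: int) -> int:
    # raises IndexError if batch_size > max_capture_size
    return bs_to_padded_graph_size[batch_size]

max_num_seqs = 512
if max_num_seqs > max_capture_size:
    max_batch_size = max_capture_size      # clamp, mirroring the diff above
else:
    max_batch_size = pad_for_cudagraph(max_num_seqs)
assert max_batch_size == 256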
@@ -433,15 +444,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (VllmConfig.get_graph_batch_size(
-                self.scheduler_config.max_num_seqs) if self.scheduler_config
-                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
-                *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
         (
             mamba_cache_tensors,
             state_indices_tensor,

View File

@@ -6,7 +6,7 @@ from torch import nn
 from transformers import MambaConfig
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -195,6 +195,17 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)

+        if self.scheduler_config is not None and \
+                not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                    vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        else:
+            self.max_batch_size = 8192 + 2

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
@@ -208,15 +219,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (VllmConfig.get_graph_batch_size(
-                self.scheduler_config.max_num_seqs) if self.scheduler_config
-                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
-                *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
         (
             mamba_cache_tensors,

View File

@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
 import numpy as np
 import torch
@@ -459,7 +459,7 @@ class GPUModelRunner:
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
-            num_input_tokens = self._get_padded_batch_size(
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_scheduled_tokens)
         else:
             # Eager mode.
@@ -641,10 +641,3 @@ class GPUModelRunner:
                 torch.zeros(kv_cache_shape,
                             dtype=self.kv_cache_dtype,
                             device=self.device))
-
-    def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
-        # TODO: Optimize this?
-        for size in self.cudagraph_batch_sizes:
-            if batch_size <= size:
-                return size
-        return None
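
Not part of the diff: a standalone comparison of the removed linear scan and the table lookup that replaces it. The capture sizes below are hypothetical, and the table mirrors what `CompilationConfig` precomputes above.

from typing import Optional

cudagraph_batch_sizes = [1, 2, 4, 8, 16]   # ascending, as the runner stores them
bs_to_padded_graph_size = [0, 1, 2, 4, 4, 8, 8, 8, 8,
                           16, 16, 16, 16, 16, 16, 16, 16]

def old_pad(batch_size: int) -> Optional[int]:
    # the removed helper: linear scan over the capture sizes
    for size in cudagraph_batch_sizes:
        if batch_size <= size:
            return size
    return None

def new_pad(batch_size: int) -> int:
    # the replacement path: constant-time index into the precomputed table
    return bs_to_padded_graph_size[batch_size]

assert all(old_pad(bs) == new_pad(bs) for bs in range(1, 17))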

View File

@@ -464,7 +464,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             # We will be using CUDA graph replay for this decode.
             max_len_of_block_table = self.get_max_block_per_batch()
             batch_size = len(encoder_seq_lens)
-            graph_batch_size = self.vllm_config.get_graph_batch_size(
+            graph_batch_size = self.vllm_config.pad_for_cudagraph(
                 batch_size)
             assert graph_batch_size >= batch_size
             cuda_graph_pad_size = graph_batch_size - batch_size

View File

@@ -802,7 +802,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                 max_encoder_seq_len):
             return -1

-        graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
+        graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
+            batch_size)
         assert graph_batch_size >= batch_size
         return graph_batch_size - batch_size
@@ -1014,8 +1015,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+        self.max_batchsize_to_capture = \
+            self.vllm_config.compilation_config.max_capture_size

         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)

View File

@@ -37,10 +37,6 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 _PAD_SLOT_ID = -1
-_BATCH_SIZE_ALIGNMENT = 8
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
-]
 TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")