mirror of https://git.datalinker.icu/vllm-project/vllm.git
[core] clean up cudagraph batchsize padding logic (#10996)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 34f1a806d5
commit be39e3cd18
@@ -1,7 +1,7 @@
 import pytest

 from tests.utils import multi_gpu_test
-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams

 from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])

@@ -5,7 +5,7 @@ Run `pytest tests/models/test_mamba.py`.
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams

 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
            len(example_prompts)):
        example_prompts.append(example_prompts[0])

@@ -4,7 +4,6 @@ from typing import List
 import pytest
 import torch

-from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
+        expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(

@@ -3,7 +3,6 @@ from typing import List
 import pytest
 import torch

-from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
@@ -177,7 +176,8 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)

-    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = model_runner.vllm_config.pad_for_cudagraph(
+        len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0
vllm/config.py
@@ -2354,6 +2354,12 @@ class CompilationConfig(BaseModel):
     # not configurable, computed after init
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
+    # optimization:
+    # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
+    # since we know all keys are in a range [0, max_capture_size],
+    # we can optimize it to List[int] for better lookup performance.
+    bs_to_padded_graph_size: List[int] = PrivateAttr
+
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
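A quick aside on the comment above: because every key lies in the dense range [0, max_capture_size], a plain list indexed by batch size can stand in for the dict. A minimal sketch with made-up sizes (illustrative, not vLLM code):

```python
# Minimal sketch, assuming a toy capture set {1, 2, 4, 8}: the dict and the
# dense list encode the same batch-size -> padded-size mapping.
max_capture_size = 8
padding_as_dict = {0: 0, 1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}

# Same mapping as a list: the index is the batch size, the value the padding.
padding_as_list = [padding_as_dict[bs] for bs in range(max_capture_size + 1)]

assert padding_as_list[3] == padding_as_dict[3] == 4  # O(1) index, no hashing
```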
@@ -2365,6 +2371,19 @@ class CompilationConfig(BaseModel):
     # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr

+    def __repr__(self) -> str:
+        exclude = {
+            "static_forward_context",
+            "enabled_custom_ops",
+            "disabled_custom_ops",
+            "compilation_time",
+            "bs_to_padded_graph_size",
+            "pass_config",
+        }
+        return self.model_dump_json(exclude=exclude, exclude_unset=True)
+
+    __str__ = __repr__
+
     @classmethod
     def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
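For context, a hedged sketch of how the new `__repr__` is expected to behave (inferred from the diff; the printed output is an expectation, not verified here):

```python
# Sketch: __str__ is aliased to __repr__, so printing a CompilationConfig goes
# through model_dump_json(exclude=..., exclude_unset=True). Bulky private
# fields such as bs_to_padded_graph_size are kept out of the output.
from vllm.config import CompilationConfig

cfg = CompilationConfig()
print(cfg)  # compact JSON; an untouched config is expected to print "{}"
```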
@@ -2450,18 +2469,22 @@ class CompilationConfig(BaseModel):

         # sort to make sure cudagraph capture sizes are in descending order
         self.capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.capture_sizes[
+            0] if self.capture_sizes else 0

-
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_capture_size + 1)
+        ]
+        for end, start in zip(self.capture_sizes,
+                              self.capture_sizes[1:] + [0]):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
+        self.bs_to_padded_graph_size[
+            self.max_capture_size] = self.max_capture_size


 @dataclass
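The loop above walks consecutive pairs of the descending `capture_sizes` and rounds every batch size up to the nearest captured size. A standalone sketch of the same construction, assuming a small made-up capture list:

```python
# Standalone sketch of the table construction, assuming capture_sizes=[8, 4, 2, 1]
# (already sorted in descending order, as in the diff).
capture_sizes = [8, 4, 2, 1]
max_capture_size = capture_sizes[0]

bs_to_padded = [0] * (max_capture_size + 1)
for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
    for bs in range(start, end):
        # a batch size that equals a capture size maps to itself;
        # anything in between rounds up to the next capture size
        bs_to_padded[bs] = start if bs == start else end
bs_to_padded[max_capture_size] = max_capture_size

assert bs_to_padded == [0, 1, 2, 4, 4, 8, 8, 8, 8]
```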
@@ -2491,40 +2514,12 @@ class VllmConfig:
                                                   init=True)  # type: ignore
     instance_id: str = ""

-    @staticmethod
-    def get_graph_batch_size(batch_size: int) -> int:
-        """Returns the padded batch size given actual batch size.
-
-        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
-        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
-        """
-        if batch_size <= 2:
-            return batch_size
-        elif batch_size <= 4:
-            return 4
-        else:
-            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
-                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-    @staticmethod
-    def get_max_graph_batch_size(max_num_seqs: int) -> int:
-        """
-        max_num_seqs: Maximum number of sequences in a batch.
-        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-
-        pad the max_num_seqs if necessary by calling get_graph_batch_size,
-        which will deal with some edge cases like 1, 2, 4.
-
-        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
-        size. if not, it means the padded size is larger than the largest size
-        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
-        _BATCH_SIZES_TO_CAPTURE.
-        """
-        padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
-        if padded_size in _BATCH_SIZES_TO_CAPTURE:
-            return padded_size
-        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
-        return _BATCH_SIZES_TO_CAPTURE[-1]
+    def pad_for_cudagraph(self, batch_size: int) -> int:
+        # if batch_size > self.compilation_config.max_capture_size,
+        # it should raise an IndexError.
+        # the caller should make sure the batch_size is within the range,
+        # i.e., batch_size <= self.compilation_config.max_capture_size
+        return self.compilation_config.bs_to_padded_graph_size[batch_size]

     @staticmethod
     def _get_quantization_config(
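A hedged usage sketch for the new method; the model name is only an example, but the call pattern mirrors the updated tests above:

```python
# Sketch: pad a runtime batch size to the nearest captured size.
from vllm.engine.arg_utils import EngineArgs

vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()

batch_size = 5
padded = vllm_config.pad_for_cudagraph(batch_size)
assert padded >= batch_size  # e.g. 8 with the default capture sizes

# Staying in range is the caller's job: past max_capture_size the list lookup
# raises IndexError, as the comment in the diff states.
limit = vllm_config.compilation_config.max_capture_size
try:
    vllm_config.pad_for_cudagraph(limit + 1)
except IndexError:
    pass
```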
@@ -2618,27 +2613,7 @@ class VllmConfig:
             self.compilation_config.pass_config.enable_reshape = False
             self.compilation_config.level = CompilationLevel.PIECEWISE

-        if not envs.VLLM_USE_V1:
-            max_batchsize_to_capture = 0
-            if self.scheduler_config is not None and \
-                self.model_config is not None and \
-                not self.model_config.enforce_eager:
-                max_batchsize_to_capture = \
-                    self.get_max_graph_batch_size(
-                        self.scheduler_config.max_num_seqs)
-            batch_size_capture_list = [
-                size for size in _BATCH_SIZES_TO_CAPTURE
-                if size <= max_batchsize_to_capture
-            ]
-        else:
-            batch_size_capture_list = []
-            if self.model_config is not None and \
-                not self.model_config.enforce_eager:
-                batch_size_capture_list = [1, 2, 4
-                                           ] + [i for i in range(8, 513, 8)]
-
-        self.compilation_config.init_with_cudagraph_sizes(
-            batch_size_capture_list)
+        self._set_cudagraph_sizes()

         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
@@ -2659,6 +2634,70 @@ class VllmConfig:
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

+    def _set_cudagraph_sizes(self):
+        """
+        cudagraph batchsize padding logic:
+
+        `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
+        batch sizes that cudagraph will capture.
+
+        Depending on the engine's configuration of `max_num_seqs`, the
+        candidate batch sizes to capture cudagraph will shrink to the subset
+        which just cover the range of `[1, max_num_seqs]`. In the common case,
+        `max_num_seqs` is 256, and the cudagraph batch sizes will be
+        `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
+
+        However, if users specify the cudagraph capture sizes through
+        compilation config, we will use the specified sizes instead.
+
+        In the end, `vllm_config.compilation_config.capture_sizes` will be the
+        final sizes to capture cudagraph (in descending order).
+
+        During runtime, if batchsize is larger than
+        `vllm_config.compilation_config.capture_sizes`,
+        no cudagraph will be used.
+        If the batch size is no larger than
+        `vllm_config.compilation_config.capture_sizes`,
+        we can quickly find the padded graph size for a given batch size by
+        looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
+        """
+
+        # calculate the default `batch_size_capture_list`
+        if not envs.VLLM_USE_V1:
+            batch_size_capture_list = []
+            max_batchsize_to_capture = 0
+            if self.scheduler_config is not None and \
+                self.model_config is not None and \
+                not self.model_config.enforce_eager:
+
+                possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
+                # find the minimum size that is larger than max_num_seqs,
+                # which then becomes the max_batchsize_to_capture
+                larger_sizes = [
+                    x for x in possible_sizes
+                    if x >= self.scheduler_config.max_num_seqs
+                ]
+                if larger_sizes:
+                    max_batchsize_to_capture = larger_sizes[0]
+                else:
+                    max_batchsize_to_capture = possible_sizes[-1]
+
+                # filter out the sizes that are
+                # larger than max_batchsize_to_capture
+                batch_size_capture_list = [
+                    size for size in possible_sizes
+                    if size <= max_batchsize_to_capture
+                ]
+        else:
+            batch_size_capture_list = []
+            if self.model_config is not None and \
+                not self.model_config.enforce_eager:
+                batch_size_capture_list = [1, 2, 4
+                                           ] + [i for i in range(8, 513, 8)]
+
+        self.compilation_config.init_with_cudagraph_sizes(
+            batch_size_capture_list)
+
     def __str__(self):
         return (
             f"model={self.model_config.model!r},"
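To make the V0 branch of `_set_cudagraph_sizes` concrete, here is a standalone sketch of the capture-list selection with an assumed `max_num_seqs` (the number is illustrative, not from the diff):

```python
# Standalone sketch of the V0 capture-list selection, assuming max_num_seqs=100.
max_num_seqs = 100
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]

# the smallest candidate >= max_num_seqs becomes the largest size to capture
larger_sizes = [x for x in possible_sizes if x >= max_num_seqs]
max_batchsize_to_capture = (larger_sizes[0]
                            if larger_sizes else possible_sizes[-1])

batch_size_capture_list = [
    s for s in possible_sizes if s <= max_batchsize_to_capture
]
assert max_batchsize_to_capture == 104  # 13 * 8
assert batch_size_capture_list[:5] == [1, 2, 4, 8, 16]
assert batch_size_capture_list[-1] == 104
```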
@@ -7,7 +7,7 @@ from transformers import JambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -420,6 +420,17 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+        if self.scheduler_config is not None and \
+            not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        else:
+            self.max_batch_size = 8192 + 2

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
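The branch added above sizes the mamba cache once at construction time. A hedged restatement as a free function, to make the three cases explicit (the function name is illustrative, not part of the diff):

```python
# Sketch of the max_batch_size selection above, pulled out as a helper.
def pick_mamba_cache_batch_size(vllm_config, scheduler_config, enforce_eager):
    comp = vllm_config.compilation_config
    if scheduler_config is not None and not enforce_eager:
        if scheduler_config.max_num_seqs > comp.max_capture_size:
            # cap at the largest size cudagraph actually captured
            return comp.max_capture_size
        # otherwise round max_num_seqs up to the nearest captured size
        return vllm_config.pad_for_cudagraph(scheduler_config.max_num_seqs)
    # eager mode (or no scheduler config): fixed upper bound, as in the diff
    return 8192 + 2
```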
@@ -433,15 +444,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (VllmConfig.get_graph_batch_size(
-                self.scheduler_config.max_num_seqs) if self.scheduler_config
-                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
-                *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
         (
             mamba_cache_tensors,
             state_indices_tensor,

@@ -6,7 +6,7 @@ from torch import nn
 from transformers import MambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -195,6 +195,17 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
+        if self.scheduler_config is not None and \
+            not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        else:
+            self.max_batch_size = 8192 + 2

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
@@ -208,15 +219,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (VllmConfig.get_graph_batch_size(
-                self.scheduler_config.max_num_seqs) if self.scheduler_config
-                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
-                *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())

         (
             mamba_cache_tensors,

@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple

 import numpy as np
 import torch
@@ -459,7 +459,7 @@ class GPUModelRunner:
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
-            num_input_tokens = self._get_padded_batch_size(
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_scheduled_tokens)
         else:
             # Eager mode.
@@ -641,10 +641,3 @@ class GPUModelRunner:
             torch.zeros(kv_cache_shape,
                         dtype=self.kv_cache_dtype,
                         device=self.device))
-
-    def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
-        # TODO: Optimize this?
-        for size in self.cudagraph_batch_sizes:
-            if batch_size <= size:
-                return size
-        return None
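The deleted `_get_padded_batch_size` scanned `cudagraph_batch_sizes` linearly; `pad_for_cudagraph` answers the same question from the precomputed table. A hedged sketch checking that the two approaches agree within the captured range (sizes are made up):

```python
# Sketch: linear scan (old) vs precomputed table (new) yield the same padding
# for every batch size within the captured range.
from bisect import bisect_left

capture_sizes_ascending = [1, 2, 4, 8, 16, 24, 32]
max_capture_size = capture_sizes_ascending[-1]

def pad_by_scan(batch_size):
    # old approach: first captured size that fits, None if nothing fits
    for size in capture_sizes_ascending:
        if batch_size <= size:
            return size
    return None

# new approach: build the lookup table once, then index in O(1)
table = [0] * (max_capture_size + 1)
for bs in range(1, max_capture_size + 1):
    table[bs] = capture_sizes_ascending[bisect_left(capture_sizes_ascending, bs)]

assert all(pad_by_scan(bs) == table[bs]
           for bs in range(1, max_capture_size + 1))
```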
@@ -464,7 +464,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             # We will be using CUDA graph replay for this decode.
             max_len_of_block_table = self.get_max_block_per_batch()
             batch_size = len(encoder_seq_lens)
-            graph_batch_size = self.vllm_config.get_graph_batch_size(
+            graph_batch_size = self.vllm_config.pad_for_cudagraph(
                 batch_size)
             assert graph_batch_size >= batch_size
             cuda_graph_pad_size = graph_batch_size - batch_size
@@ -802,7 +802,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                 max_encoder_seq_len):
             return -1

-        graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
+        graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
+            batch_size)
         assert graph_batch_size >= batch_size
         return graph_batch_size - batch_size
@@ -1014,8 +1015,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+        self.max_batchsize_to_capture = \
+            self.vllm_config.compilation_config.max_capture_size

         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -37,10 +37,6 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 _PAD_SLOT_ID = -1
-_BATCH_SIZE_ALIGNMENT = 8
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
-]

 TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")