diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py
index cae25ae9fa2c..057b04349e8b 100644
--- a/tests/models/decoder_only/language/test_jamba.py
+++ b/tests/models/decoder_only/language/test_jamba.py
@@ -1,7 +1,7 @@
 import pytest
 
 from tests.utils import multi_gpu_test
-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 
 from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])
 
diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py
index 35018c3c14de..06739e8f0225 100644
--- a/tests/models/decoder_only/language/test_mamba.py
+++ b/tests/models/decoder_only/language/test_mamba.py
@@ -5,7 +5,7 @@ Run `pytest tests/models/test_mamba.py`.
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 
 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])
 
diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py
index 5289c91f201c..a6b3cb5759f2 100644
--- a/tests/worker/test_encoder_decoder_model_runner.py
+++ b/tests/worker/test_encoder_decoder_model_runner.py
@@ -4,7 +4,6 @@ from typing import List
 import pytest
 import torch
 
-from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
    # accordingly.
-    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
+        expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 4055524f3e0c..aabe913c242e 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -3,7 +3,6 @@ from typing import List
 import pytest
 import torch
 
-from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
@@ -177,7 +176,8 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)
 
-    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = model_runner.vllm_config.pad_for_cudagraph(
+        len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0
diff --git a/vllm/config.py b/vllm/config.py
index 08a7b607630a..12ed80c366e4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2354,6 +2354,12 @@ class CompilationConfig(BaseModel):
     # not configurable, computed after init
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
+    max_capture_size: int = PrivateAttr
+    # optimization:
+    # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
+    # since we know all keys are in a range [0, max_capture_size],
+    # we can optimize it to List[int] for better lookup performance.
+    bs_to_padded_graph_size: List[int] = PrivateAttr
 
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
@@ -2365,6 +2371,19 @@ class CompilationConfig(BaseModel):
     # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr
 
+    def __repr__(self) -> str:
+        exclude = {
+            "static_forward_context",
+            "enabled_custom_ops",
+            "disabled_custom_ops",
+            "compilation_time",
+            "bs_to_padded_graph_size",
+            "pass_config",
+        }
+        return self.model_dump_json(exclude=exclude, exclude_unset=True)
+
+    __str__ = __repr__
+
     @classmethod
     def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
@@ -2450,18 +2469,22 @@ class CompilationConfig(BaseModel):
 
         # sort to make sure cudagraph capture sizes are in descending order
         self.capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.capture_sizes[
+            0] if self.capture_sizes else 0
-
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_capture_size + 1)
+        ]
+        for end, start in zip(self.capture_sizes,
+                              self.capture_sizes[1:] + [0]):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
+        self.bs_to_padded_graph_size[
+            self.max_capture_size] = self.max_capture_size
 
 
 @dataclass
@@ -2491,40 +2514,12 @@ class VllmConfig:
                                               init=True)  # type: ignore
     instance_id: str = ""
 
-    @staticmethod
-    def get_graph_batch_size(batch_size: int) -> int:
-        """Returns the padded batch size given actual batch size.
-
-        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
-        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
-        """
-        if batch_size <= 2:
-            return batch_size
-        elif batch_size <= 4:
-            return 4
-        else:
-            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
-                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-    @staticmethod
-    def get_max_graph_batch_size(max_num_seqs: int) -> int:
-        """
-        max_num_seqs: Maximum number of sequences in a batch.
-        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-
-        pad the max_num_seqs if necessary by calling get_graph_batch_size,
-        which will deal with some edge cases like 1, 2, 4.
-
-        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
-        size. if not, it means the padded size is larger than the largest size
-        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
-        _BATCH_SIZES_TO_CAPTURE.
-        """
-        padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
-        if padded_size in _BATCH_SIZES_TO_CAPTURE:
-            return padded_size
-        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
-        return _BATCH_SIZES_TO_CAPTURE[-1]
+    def pad_for_cudagraph(self, batch_size: int) -> int:
+        # if batch_size > self.compilation_config.max_capture_size,
+        # it should raise an IndexError.
+        # the caller should make sure the batch_size is within the range,
+        # i.e., batch_size <= self.compilation_config.max_capture_size
+        return self.compilation_config.bs_to_padded_graph_size[batch_size]
 
     @staticmethod
     def _get_quantization_config(
@@ -2618,27 +2613,7 @@ class VllmConfig:
                 self.compilation_config.pass_config.enable_reshape = False
                 self.compilation_config.level = CompilationLevel.PIECEWISE
 
-        if not envs.VLLM_USE_V1:
-            max_batchsize_to_capture = 0
-            if self.scheduler_config is not None and \
-                self.model_config is not None and \
-                    not self.model_config.enforce_eager:
-                max_batchsize_to_capture = \
-                    self.get_max_graph_batch_size(
-                        self.scheduler_config.max_num_seqs)
-            batch_size_capture_list = [
-                size for size in _BATCH_SIZES_TO_CAPTURE
-                if size <= max_batchsize_to_capture
-            ]
-        else:
-            batch_size_capture_list = []
-            if self.model_config is not None and \
-                not self.model_config.enforce_eager:
-                batch_size_capture_list = [1, 2, 4
-                                           ] + [i for i in range(8, 513, 8)]
-
-        self.compilation_config.init_with_cudagraph_sizes(
-            batch_size_capture_list)
+        self._set_cudagraph_sizes()
 
         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
@@ -2659,6 +2634,70 @@ class VllmConfig:
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
 
+    def _set_cudagraph_sizes(self):
+        """
+        cudagraph batchsize padding logic:
+
+        `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
+        batch sizes that cudagraph will capture.
+
+        Depending on the engine's configuration of `max_num_seqs`, the
+        candidate batch sizes to capture cudagraph will shrink to the subset
+        that just covers the range `[1, max_num_seqs]`. In the common case,
+        `max_num_seqs` is 256, and the cudagraph batch sizes will be
+        `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
+
+        However, if users specify the cudagraph capture sizes through
+        compilation config, we will use the specified sizes instead.
+
+        In the end, `vllm_config.compilation_config.capture_sizes` will hold
+        the final cudagraph capture sizes (in descending order).
+
+        At runtime, if the batch size is larger than the largest size
+        in `vllm_config.compilation_config.capture_sizes`,
+        no cudagraph will be used.
+        If the batch size is no larger than
+        that largest captured size,
+        we can quickly find the padded graph size for a given batch size by
+        looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
+        """
+
+        # calculate the default `batch_size_capture_list`
+        if not envs.VLLM_USE_V1:
+            batch_size_capture_list = []
+            max_batchsize_to_capture = 0
+            if self.scheduler_config is not None and \
+                self.model_config is not None and \
+                    not self.model_config.enforce_eager:
+
+                possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
+                # find the minimum size that is larger than max_num_seqs,
+                # which then becomes the max_batchsize_to_capture
+                larger_sizes = [
+                    x for x in possible_sizes
+                    if x >= self.scheduler_config.max_num_seqs
+                ]
+                if larger_sizes:
+                    max_batchsize_to_capture = larger_sizes[0]
+                else:
+                    max_batchsize_to_capture = possible_sizes[-1]
+
+                # filter out the sizes that are
+                # larger than max_batchsize_to_capture
+                batch_size_capture_list = [
+                    size for size in possible_sizes
+                    if size <= max_batchsize_to_capture
+                ]
+        else:
+            batch_size_capture_list = []
+            if self.model_config is not None and \
+                not self.model_config.enforce_eager:
+                batch_size_capture_list = [1, 2, 4
+                                           ] + [i for i in range(8, 513, 8)]
+
+        self.compilation_config.init_with_cudagraph_sizes(
+            batch_size_capture_list)
+
     def __str__(self):
         return (
             f"model={self.model_config.model!r},"
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 6bb4c13ab35d..831db2ae52d7 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -7,7 +7,7 @@ from transformers import JambaConfig
 
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -420,6 +420,17 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+        if self.scheduler_config is not None and \
+                not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                    vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        else:
+            self.max_batch_size = 8192 + 2
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -433,15 +444,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (VllmConfig.get_graph_batch_size(
-                self.scheduler_config.max_num_seqs) if self.scheduler_config
-                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
-                *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
 
         (
             mamba_cache_tensors,
             state_indices_tensor,
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 1f5cd0271189..06c8d9723cd0 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import MambaConfig
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -195,6 +195,17 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
+        if self.scheduler_config is not None and \
+                not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                    vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        else:
+            self.max_batch_size = 8192 + 2
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
@@ -208,15 +219,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (VllmConfig.get_graph_batch_size(
-                self.scheduler_config.max_num_seqs) if self.scheduler_config
-                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
-                *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
 
         (
             mamba_cache_tensors,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index aa91255e68d4..f24942068d1f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -459,7 +459,7 @@ class GPUModelRunner:
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
-            num_input_tokens = self._get_padded_batch_size(
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_scheduled_tokens)
         else:
            # Eager mode.
@@ -641,10 +641,3 @@ class GPUModelRunner:
                 torch.zeros(kv_cache_shape,
                             dtype=self.kv_cache_dtype,
                             device=self.device))
-
-    def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
-        # TODO: Optimize this?
-        for size in self.cudagraph_batch_sizes:
-            if batch_size <= size:
-                return size
-        return None
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 5697fbbaa204..bff01320d792 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -464,7 +464,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             # We will be using CUDA graph replay for this decode.
             max_len_of_block_table = self.get_max_block_per_batch()
             batch_size = len(encoder_seq_lens)
-            graph_batch_size = self.vllm_config.get_graph_batch_size(
+            graph_batch_size = self.vllm_config.pad_for_cudagraph(
                 batch_size)
             assert graph_batch_size >= batch_size
             cuda_graph_pad_size = graph_batch_size - batch_size
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 26fd486130ce..6ff98a8f1bab 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -802,7 +802,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                 max_encoder_seq_len):
             return -1
 
-        graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
+        graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
+            batch_size)
         assert graph_batch_size >= batch_size
         return graph_batch_size - batch_size
 
@@ -1014,8 +1015,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+        self.max_batchsize_to_capture = \
+            self.vllm_config.compilation_config.max_capture_size
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index e6322e095bbb..9cf25387560d 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -37,10 +37,6 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = -1
-_BATCH_SIZE_ALIGNMENT = 8
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
-]
 
 TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
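
Note (illustrative, not part of the patch): the padding scheme this diff introduces can be summarized in a short standalone sketch. The helper names `derive_capture_sizes` and `build_padding_table` are made up for this example; the first mirrors the non-V1 branch of `_set_cudagraph_sizes` above, the second mirrors the `bs_to_padded_graph_size` precomputation in `CompilationConfig.init_with_cudagraph_sizes`.

from typing import List


def derive_capture_sizes(max_num_seqs: int) -> List[int]:
    # Candidate sizes from the _set_cudagraph_sizes docstring.
    possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
    # The smallest candidate >= max_num_seqs becomes the largest size to
    # capture; fall back to the largest candidate if none qualifies.
    larger_sizes = [x for x in possible_sizes if x >= max_num_seqs]
    max_capture = larger_sizes[0] if larger_sizes else possible_sizes[-1]
    # Keep the selected sizes in descending order, as the PR does.
    return sorted((s for s in possible_sizes if s <= max_capture),
                  reverse=True)


def build_padding_table(capture_sizes: List[int]) -> List[int]:
    # capture_sizes is descending; index i gives the padded size for batch i.
    max_capture = capture_sizes[0] if capture_sizes else 0
    table = [0] * (max_capture + 1)
    for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
        for bs in range(start, end):
            table[bs] = start if bs == start else end
    table[max_capture] = max_capture
    return table


capture_sizes = derive_capture_sizes(max_num_seqs=256)
table = build_padding_table(capture_sizes)
assert table[1] == 1 and table[3] == 4      # 1, 2, 4 handled exactly
assert table[5] == 8 and table[250] == 256  # otherwise round up to next size
# Indexing past the largest capture size raises IndexError, matching the
# contract documented on VllmConfig.pad_for_cudagraph.

The dense list replaces a dict on purpose: every batch size of interest lies in `[0, max_capture_size]`, so a plain list index gives the same answer with a cheaper O(1) lookup, which is exactly the optimization noted in the `CompilationConfig` comment above.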
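
Call sites no longer use the removed static helpers on `VllmConfig`; they go through an engine config, as the updated Jamba/Mamba tests do. A minimal usage sketch along those lines (the model id is a placeholder; building a real engine config requires a model your environment can resolve):

from vllm.engine.arg_utils import EngineArgs

vllm_config = EngineArgs(
    model="state-spaces/mamba-130m-hf").create_engine_config()

batch_size = 5
if batch_size <= vllm_config.compilation_config.max_capture_size:
    # Pad up to the nearest captured cudagraph batch size.
    num_input_tokens = vllm_config.pad_for_cudagraph(batch_size)
else:
    # Larger than the largest captured size: run without cudagraph.
    num_input_tokens = batch_size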