[torch.compile] remove compilation_context and simplify code (#10838)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-12-02 22:19:02 -08:00; committed by GitHub
parent 21fe7b481a
commit dc5ce861bf
GPG Key ID: B5690EEEBB952194
14 changed files with 128 additions and 143 deletions

View File

@@ -7,7 +7,6 @@ import torch
 from torch import nn
 from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
+       cudagraph_capture_sizes=[1, 2],
    ))
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -96,7 +96,6 @@ def test_simple_piecewise_compile():
            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
-       with set_compile_context([1, 2]):
        model(inputs)

        model(torch.randn(2).cuda())
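
Taken together, these hunks replace the call-time context manager with a config field: the cudagraph capture sizes are now declared on CompilationConfig up front instead of being passed through set_compile_context at the first model call. A minimal before/after sketch, using the names from this test (values illustrative):

    # Before: capture sizes supplied at call time through a context manager.
    with set_compile_context([1, 2]):
        model(inputs)

    # After: capture sizes declared once on the compilation config.
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        cudagraph_capture_sizes=[1, 2],
    ))
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')
    model(inputs)  # no wrapper needed; the sizes come from the config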

View File

@@ -13,7 +13,6 @@ import torch
 from torch import nn
 from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -256,6 +255,7 @@ def run_model(llama_config,
        compilation_config = CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
+           cudagraph_capture_sizes=[1, 2],
        )
        if split_attn:
            compilation_config.splitting_ops = ["silly.attention"]
@@ -273,7 +273,6 @@ def run_model(llama_config,
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda()

-   with set_compile_context([1, 2]):
    model(input_ids, positions)
    model(input_ids[:2], positions[:2])
    model(input_ids[:1], positions[:1])
@@ -379,10 +378,13 @@ def benchmark():
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            splitting_ops=["silly.attention"],
+           cudagraph_capture_sizes=cudagraph_sizes,
        )
    else:
        compilation_config = CompilationConfig(
-           level=CompilationLevel.PIECEWISE, )
+           level=CompilationLevel.PIECEWISE,
+           cudagraph_capture_sizes=cudagraph_sizes,
+       )

    vllm_config = VllmConfig(compilation_config=compilation_config)
    with set_current_vllm_config(vllm_config):
@@ -396,7 +398,6 @@ def benchmark():
        graphs = {}

-       with set_compile_context(cudagraph_sizes):
        model(input_ids, positions)
        for b in cudagraph_sizes[::-1]:
            if not piecewise:

View File

@@ -1,8 +1,8 @@
 import pytest

 from tests.utils import multi_gpu_test
+from vllm.config import VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.worker.model_runner import _get_graph_batch_size

 from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
    # This test is for verifying that mamba cache is padded to CG captured
    # batch size. If it's not, a torch RuntimeError will be raised because
    # tensor dimensions aren't compatible
-   while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+   while len(example_prompts) == VllmConfig.get_graph_batch_size(
+           len(example_prompts)):
        example_prompts.append(example_prompts[0])

    try:
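
The padding helper referenced here rounds a batch size up to the nearest CUDA-graph-captured size (1, 2, 4, then multiples of 8; see the VllmConfig.get_graph_batch_size hunk later in this diff), so the loop keeps appending prompts until the prompt count is not already a captured size and padding must actually happen. A small worked example of the values involved:

    # With 8 prompts, 8 is itself a captured size, so one more prompt is appended;
    # the resulting 9 prompts are then padded up to 16 at CUDA graph replay time.
    assert VllmConfig.get_graph_batch_size(8) == 8
    assert VllmConfig.get_graph_batch_size(9) == 16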

View File

@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from vllm.config import VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.worker.model_runner import _get_graph_batch_size

 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
    # This test is for verifying that mamba cache is padded to CG captured
    # batch size. If it's not, a torch RuntimeError will be raised because
    # tensor dimensions aren't compatible
-   while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+   while len(example_prompts) == VllmConfig.get_graph_batch_size(
+           len(example_prompts)):
        example_prompts.append(example_prompts[0])

    try:

View File

@@ -4,12 +4,12 @@ from typing import List
 import pytest
 import torch

+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-from vllm.worker.model_runner import _get_graph_batch_size

 BATCH_SIZES = [1, 4, 16, 64, 256]
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
    # With CUDA Graph capture and replay enabled, the decoder and encoder
    # input sequences will be padded. Create the expected padded tensors
    # accordingly.
-   graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+   graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
    padded_encoder_seq_lens = encoder_seq_lens + list(

View File

@@ -3,13 +3,14 @@ from typing import List
 import pytest
 import torch

+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
-from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
+from vllm.worker.model_runner import ModelRunner


 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
    assert len(slot_mapping) == len(input_tokens)

-   expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
+   expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0

View File

@@ -242,10 +242,6 @@ class VllmBackend:
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
-       # config is updated now, because only here can
-       # we get the sizes to capture for cudagraph
-       # from compilation context
-       self.compilation_configs.init_during_runtime()
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(

View File

@@ -1,23 +0,0 @@
-from contextlib import contextmanager
-from typing import Any
-
-_compile_context: Any = None
-
-
-def get_compile_context() -> Any:
-    """Get the current compile context."""
-    return _compile_context
-
-
-@contextmanager
-def set_compile_context(context: Any):
-    """A context manager that stores the current compile context,
-    usually it is a list of sizes to specialize.
-    """
-    global _compile_context
-    prev_context = _compile_context
-    _compile_context = context
-    try:
-        yield
-    finally:
-        _compile_context = prev_context

View File

@@ -2357,15 +2357,10 @@ class CompilationConfig(BaseModel):
        from vllm.compilation.backends import VllmBackend
        return VllmBackend(self)

-   def init_during_runtime(self):
+   def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
        """To complete the initialization of config,
-       we need to know the compile context, which is only available
-       during the first run of the model.
-       """
-       from vllm.compilation.compile_context import get_compile_context
-       context = get_compile_context()
-       context = copy.deepcopy(context) if context is not None else []
-       sizes_to_specialize: List[int] = context
+       we need to know the cudagraph sizes."""
+
        if self.cudagraph_capture_sizes is None:
            self.capture_sizes = sizes_to_specialize
        else:
@@ -2386,6 +2381,21 @@ class CompilationConfig(BaseModel):
                self.inductor_compile_sizes = []
            self.compile_sizes = self.inductor_compile_sizes

+       # sort to make sure cudagraph capture sizes are in descending order
+       self.capture_sizes.sort(reverse=True)
+
+
+_BATCH_SIZE_ALIGNMENT = 8
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
+# NOTE: get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
+]

 @dataclass
 class VllmConfig:
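
init_with_cudagraph_sizes is the replacement for init_during_runtime: instead of reading the sizes from a global compile context on the first forward pass, it receives them as an argument from VllmConfig.__post_init__ (see below), lets an explicit cudagraph_capture_sizes override them, and stores the result in descending order. A distilled sketch of that selection rule (select_capture_sizes is a hypothetical stand-in, not the real method):

    def select_capture_sizes(user_sizes, default_sizes):
        # an explicit cudagraph_capture_sizes wins; otherwise use the defaults
        sizes = list(default_sizes) if user_sizes is None else list(user_sizes)
        sizes.sort(reverse=True)  # capture sizes are kept in descending order
        return sizes

    assert select_capture_sizes(None, [1, 2, 4, 8]) == [8, 4, 2, 1]
    assert select_capture_sizes([1, 2], [1, 2, 4, 8]) == [2, 1]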
@@ -2413,6 +2423,41 @@ class VllmConfig:
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore

+    @staticmethod
+    def get_graph_batch_size(batch_size: int) -> int:
+        """Returns the padded batch size given actual batch size.
+        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
+        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
+        """
+        if batch_size <= 2:
+            return batch_size
+        elif batch_size <= 4:
+            return 4
+        else:
+            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
+                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+    @staticmethod
+    def get_max_graph_batch_size(max_num_seqs: int) -> int:
+        """
+        max_num_seqs: Maximum number of sequences in a batch.
+        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+        pad the max_num_seqs if necessary by calling get_graph_batch_size,
+        which will deal with some edge cases like 1, 2, 4.
+        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
+        size. if not, it means the padded size is larger than the largest size
+        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
+        _BATCH_SIZES_TO_CAPTURE.
+        """
+        padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
+        if padded_size in _BATCH_SIZES_TO_CAPTURE:
+            return padded_size
+        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+        return _BATCH_SIZES_TO_CAPTURE[-1]
+
     @staticmethod
     def _get_quantization_config(
         model_config: ModelConfig,
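
Both helpers are pure functions moved here essentially unchanged from vllm/worker/model_runner.py (see the removal at the end of this diff), so their behavior is easy to pin down with a few concrete values:

    # get_graph_batch_size pads to 1, 2, 4, then multiples of _BATCH_SIZE_ALIGNMENT (8).
    assert VllmConfig.get_graph_batch_size(1) == 1
    assert VllmConfig.get_graph_batch_size(3) == 4
    assert VllmConfig.get_graph_batch_size(17) == 24
    # get_max_graph_batch_size additionally caps the result at the largest entry
    # of _BATCH_SIZES_TO_CAPTURE (8 * 1024 = 8192).
    assert VllmConfig.get_max_graph_batch_size(100) == 104
    assert VllmConfig.get_max_graph_batch_size(20000) == 8192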
@@ -2496,6 +2541,28 @@ class VllmConfig:
                self.compilation_config.pass_config.enable_reshape = False
                self.compilation_config.level = CompilationLevel.PIECEWISE

+       if not envs.VLLM_USE_V1:
+           max_batchsize_to_capture = 0
+           if self.scheduler_config is not None and \
+               self.model_config is not None and \
+                   not self.model_config.enforce_eager:
+               max_batchsize_to_capture = \
+                   self.get_max_graph_batch_size(
+                       self.scheduler_config.max_num_seqs)
+           batch_size_capture_list = [
+               size for size in _BATCH_SIZES_TO_CAPTURE
+               if size <= max_batchsize_to_capture
+           ]
+       else:
+           batch_size_capture_list = []
+           if self.model_config is not None and \
+               not self.model_config.enforce_eager:
+               batch_size_capture_list = [1, 2, 4
+                                          ] + [i for i in range(8, 513, 8)]
+
+       self.compilation_config.init_with_cudagraph_sizes(
+           batch_size_capture_list)
+
        if self.cache_config is not None and \
            self.cache_config.cpu_offload_gb > 0 and \
                self.compilation_config.level != CompilationLevel.NO_COMPILATION:
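
This is where the capture sizes now come from: __post_init__ computes a default batch_size_capture_list (the V0 list is bounded by the scheduler's max_num_seqs, the V1 list is fixed at [1, 2, 4, 8, 16, ..., 512]) and hands it to init_with_cudagraph_sizes. A worked sketch of the V0 branch, assuming a hypothetical max_num_seqs of 40 with enforce_eager disabled:

    _BATCH_SIZE_ALIGNMENT = 8  # mirrors the constant added above
    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
        _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
    ]

    cap = VllmConfig.get_max_graph_batch_size(40)  # == 40
    v0_list = [size for size in _BATCH_SIZES_TO_CAPTURE if size <= cap]
    assert v0_list == [1, 2, 4, 8, 16, 24, 32, 40]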

View File

@@ -7,7 +7,7 @@ from transformers import JambaConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -25,8 +25,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                      _get_graph_batch_size)

 from .interfaces import HasInnerState, SupportsLoRA
 from .utils import maybe_prefix
@@ -404,7 +402,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if self.mamba_cache is None:
-           max_batch_size = (_get_graph_batch_size(
+           max_batch_size = (VllmConfig.get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)

View File

@@ -6,7 +6,7 @@ from torch import nn
 from transformers import MambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -23,8 +23,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                      _get_graph_batch_size)

 from .utils import maybe_prefix
@@ -187,7 +185,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if self.mamba_cache is None:
-           max_batch_size = (_get_graph_batch_size(
+           max_batch_size = (VllmConfig.get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
            self.mamba_cache = MambaCacheManager(

View File

@@ -8,7 +8,6 @@ import torch
 import torch.distributed
 import torch.nn as nn

-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
@@ -100,7 +99,11 @@ class GPUModelRunner:
                              == CompilationLevel.PIECEWISE
                              and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
-       self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
+       # The convention is different.
+       # self.cudagraph_batch_sizes sorts in ascending order.
+       # The batch sizes in the config are in descending order.
+       self.cudagraph_batch_sizes = list(
+           reversed(self.vllm_config.compilation_config.capture_sizes))

        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
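
The comment in the hunk above is the whole story: capture_sizes comes out of the config in descending order (init_with_cudagraph_sizes sorts it that way), while the V1 runner's existing code expects ascending sizes, hence the reversed(). Illustrative values:

    capture_sizes = [8, 4, 2, 1]  # as stored on compilation_config
    cudagraph_batch_sizes = list(reversed(capture_sizes))
    assert cudagraph_batch_sizes == [1, 2, 4, 8]  # ascending, as the runner expects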
@@ -548,7 +551,6 @@ class GPUModelRunner:
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]
-       with set_compile_context(self.cudagraph_batch_sizes):
        # Trigger compilation for general shape.
        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                        dummy_kv_caches)

View File

@@ -25,8 +25,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUBuilder,
-                                      ModelInputForGPUWithSamplingMetadata,
-                                      _get_graph_batch_size)
+                                      ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict)
@@ -465,7 +464,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
            # We will be using CUDA graph replay for this decode.
            max_len_of_block_table = self.get_max_block_per_batch()
            batch_size = len(encoder_seq_lens)
-           graph_batch_size = _get_graph_batch_size(batch_size)
+           graph_batch_size = self.vllm_config.get_graph_batch_size(
+               batch_size)
            assert graph_batch_size >= batch_size
            cuda_graph_pad_size = graph_batch_size - batch_size
            # extend the cross_block_tables and encoder_seq_lens to match

View File

@@ -18,7 +18,6 @@ import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
@@ -63,16 +62,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 LORA_WARMUP_RANK = 8
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
 _NUM_WARMUP_ITERS = 2

 TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
@@ -763,7 +753,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                            max_decode_seq_len: int,
                            max_encoder_seq_len: int = 0) -> bool:
        return (decode_only and not self.runner.model_config.enforce_eager
-               and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                and batch_size <= self.runner.max_batchsize_to_capture)
@@ -811,7 +800,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                max_encoder_seq_len):
            return -1

-       graph_batch_size = _get_graph_batch_size(batch_size)
+       graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
        assert graph_batch_size >= batch_size
        return graph_batch_size - batch_size
@@ -1023,7 +1012,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-       self.max_batchsize_to_capture = _get_max_graph_batch_size(
+       self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
            self.scheduler_config.max_num_seqs)

        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
@@ -1333,13 +1322,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                dtype=self.model_config.dtype,
                                device=self.device)

-       graph_batch_size = self.max_batchsize_to_capture
-       batch_size_capture_list = [
-           bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
-       ]
-       if self.model_config.enforce_eager:
-           batch_size_capture_list = []
-       with set_compile_context(batch_size_capture_list):
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return
@@ -1459,18 +1441,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                dtype=self.model_config.dtype,
                                device=self.device)

-       graph_batch_size = self.max_batchsize_to_capture
-       batch_size_capture_list = [
-           bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
-       ]
        with self.attn_state.graph_capture(
                max_batch_size), graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
-               for batch_size in reversed(batch_size_capture_list):
+               for batch_size in \
+                       self.vllm_config.compilation_config.capture_sizes:
                    attn_metadata = (
                        self.attn_state.graph_capture_get_metadata_for_batch(
                            batch_size,
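
The V0 capture loop makes the opposite choice from the V1 runner: since capture_sizes is already sorted descending, iterating it directly still captures the largest batch size first (the memory-saving order the comment mentions), so both the locally rebuilt batch_size_capture_list and the reversed() call go away. Sketch of the change:

    # Before: list rebuilt locally, iterated largest-first via reversed().
    for batch_size in reversed(batch_size_capture_list):
        ...
    # After: the config already holds the sizes in descending order.
    for batch_size in self.vllm_config.compilation_config.capture_sizes:
        ...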
@@ -1993,37 +1971,3 @@ class CUDAGraphRunner(nn.Module):
            return self.output_buffers["hidden_states"]

        return self.output_buffers
-
-
-def _get_graph_batch_size(batch_size: int) -> int:
-    """Returns the padded batch size given actual batch size.
-    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
-    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
-    """
-    if batch_size <= 2:
-        return batch_size
-    elif batch_size <= 4:
-        return 4
-    else:
-        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
-                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-
-def _get_max_graph_batch_size(max_num_seqs: int) -> int:
-    """
-    max_num_seqs: Maximum number of sequences in a batch.
-    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
-    which will deal with some edge cases like 1, 2, 4.
-    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
-    if not, it means the padded size is larger than the largest size in
-    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
-    """
-    padded_size = _get_graph_batch_size(max_num_seqs)
-    if padded_size in _BATCH_SIZES_TO_CAPTURE:
-        return padded_size
-    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
-    return _BATCH_SIZES_TO_CAPTURE[-1]