[torch.compile] remove compilation_context and simplify code (#10838)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Commit dc5ce861bf (parent 21fe7b481a)
@@ -7,7 +7,6 @@ import torch
 from torch import nn
 from torch.library import Library
 
-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
         use_cudagraph=True,
         splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
+        cudagraph_capture_sizes=[1, 2],
     ))
     with set_current_vllm_config(vllm_config):
         model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
 
-        with set_compile_context([1, 2]):
-            model(inputs)
+        model(inputs)
 
-            model(torch.randn(2).cuda())
-            model(torch.randn(1).cuda())
+        model(torch.randn(2).cuda())
+        model(torch.randn(1).cuda())
 
         input = torch.zeros(2).cuda()
         global global_counter
@@ -13,7 +13,6 @@ import torch
 from torch import nn
 from torch.library import Library
 
-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -256,6 +255,7 @@ def run_model(llama_config,
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
+        cudagraph_capture_sizes=[1, 2],
     )
     if split_attn:
         compilation_config.splitting_ops = ["silly.attention"]
@@ -273,10 +273,9 @@ def run_model(llama_config,
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
     positions = torch.arange(B).cuda()
 
-    with set_compile_context([1, 2]):
-        model(input_ids, positions)
-        model(input_ids[:2], positions[:2])
-        model(input_ids[:1], positions[:1])
+    model(input_ids, positions)
+    model(input_ids[:2], positions[:2])
+    model(input_ids[:1], positions[:1])
 
     input_ids[:2].zero_()
     output = model(input_ids[:2], positions[:2])
@@ -379,10 +378,13 @@ def benchmark():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
             splitting_ops=["silly.attention"],
+            cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:
         compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE, )
+            level=CompilationLevel.PIECEWISE,
+            cudagraph_capture_sizes=cudagraph_sizes,
+        )
 
     vllm_config = VllmConfig(compilation_config=compilation_config)
     with set_current_vllm_config(vllm_config):
@@ -396,17 +398,16 @@ def benchmark():
 
     graphs = {}
 
-    with set_compile_context(cudagraph_sizes):
-        model(input_ids, positions)
-        for b in cudagraph_sizes[::-1]:
-            if not piecewise:
-                graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph, pool=pool):
-                    output = model(input_ids[:b], positions[:b])
-                graphs[b] = (graph, output)
-            else:
-                output = model(input_ids[:b], positions[:b])
-                graphs[b] = (model, output)
+    model(input_ids, positions)
+    for b in cudagraph_sizes[::-1]:
+        if not piecewise:
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph, pool=pool):
+                output = model(input_ids[:b], positions[:b])
+            graphs[b] = (graph, output)
+        else:
+            output = model(input_ids[:b], positions[:b])
+            graphs[b] = (model, output)
     for b in cudagraph_sizes:
         if piecewise:
             # noqa is for `Function definition does not bind loop variable`
@@ -1,8 +1,8 @@
 import pytest
 
 from tests.utils import multi_gpu_test
+from vllm.config import VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.worker.model_runner import _get_graph_batch_size
 
 from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+            len(example_prompts)):
         example_prompts.append(example_prompts[0])
 
     try:
@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from vllm.config import VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.worker.model_runner import _get_graph_batch_size
 
 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+            len(example_prompts)):
         example_prompts.append(example_prompts[0])
 
     try:
@@ -4,12 +4,12 @@ from typing import List
 import pytest
 import torch
 
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-from vllm.worker.model_runner import _get_graph_batch_size
 
 BATCH_SIZES = [1, 4, 16, 64, 256]
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
@@ -3,13 +3,14 @@ from typing import List
 import pytest
 import torch
 
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
-from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
+from vllm.worker.model_runner import ModelRunner
 
 
 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)
 
-    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0
@@ -242,10 +242,6 @@ class VllmBackend:
         assert not self._called, "VllmBackend can only be called once"
 
         self.graph = graph
-        # config is updated now, because only here can
-        # we get the sizes to capture for cudagraph
-        # from compilation context
-        self.compilation_configs.init_during_runtime()
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
@@ -1,23 +0,0 @@
-from contextlib import contextmanager
-from typing import Any
-
-_compile_context: Any = None
-
-
-def get_compile_context() -> Any:
-    """Get the current compile context."""
-    return _compile_context
-
-
-@contextmanager
-def set_compile_context(context: Any):
-    """A context manager that stores the current compile context,
-    usually it is a list of sizes to specialize.
-    """
-    global _compile_context
-    prev_context = _compile_context
-    _compile_context = context
-    try:
-        yield
-    finally:
-        _compile_context = prev_context
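
Illustration (not part of the patch): the deleted module above held a process-global "compile context" that the first model run used to pass cudagraph batch sizes to the compiler. After this change the sizes travel through the config instead. A minimal sketch of the new pattern, using only APIs visible in the hunks above; the [1, 2] capture sizes and SillyModel come from the test changes and are illustrative:

# Old (removed):
#     with set_compile_context([1, 2]):
#         model(inputs)
#
# New: declare the capture sizes up front in CompilationConfig and just run
# the model; no context manager is needed.
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    cudagraph_capture_sizes=[1, 2],
))
# model = SillyModel(vllm_config=vllm_config, prefix='')  # as in the test
# model(inputs)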
@@ -2357,15 +2357,10 @@ class CompilationConfig(BaseModel):
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(self)
 
-    def init_during_runtime(self):
+    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
         """To complete the initialization of config,
-        we need to know the compile context, which is only available
-        during the first run of the model.
-        """
-        from vllm.compilation.compile_context import get_compile_context
-        context = get_compile_context()
-        context = copy.deepcopy(context) if context is not None else []
-        sizes_to_specialize: List[int] = context
+        we need to know the cudagraph sizes."""
+
         if self.cudagraph_capture_sizes is None:
             self.capture_sizes = sizes_to_specialize
         else:
@@ -2386,6 +2381,21 @@ class CompilationConfig(BaseModel):
             self.inductor_compile_sizes = []
         self.compile_sizes = self.inductor_compile_sizes
 
+        # sort to make sure cudagraph capture sizes are in descending order
+        self.capture_sizes.sort(reverse=True)
+
+
+_BATCH_SIZE_ALIGNMENT = 8
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
+# NOTE: get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
+]
+
 
 @dataclass
 class VllmConfig:
@@ -2413,6 +2423,41 @@ class VllmConfig:
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
 
+    @staticmethod
+    def get_graph_batch_size(batch_size: int) -> int:
+        """Returns the padded batch size given actual batch size.
+
+        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
+        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
+        """
+        if batch_size <= 2:
+            return batch_size
+        elif batch_size <= 4:
+            return 4
+        else:
+            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
+                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+    @staticmethod
+    def get_max_graph_batch_size(max_num_seqs: int) -> int:
+        """
+        max_num_seqs: Maximum number of sequences in a batch.
+        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+
+        pad the max_num_seqs if necessary by calling get_graph_batch_size,
+        which will deal with some edge cases like 1, 2, 4.
+
+        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
+        size. if not, it means the padded size is larger than the largest size
+        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
+        _BATCH_SIZES_TO_CAPTURE.
+        """
+        padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
+        if padded_size in _BATCH_SIZES_TO_CAPTURE:
+            return padded_size
+        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+        return _BATCH_SIZES_TO_CAPTURE[-1]
+
     @staticmethod
     def _get_quantization_config(
         model_config: ModelConfig,
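
For reference, the padding rule these helpers implement is unchanged from the module-level functions removed from vllm.worker.model_runner (see the last hunk below): sizes 1 and 2 pass through, 3 to 4 round up to 4, and anything larger rounds up to the next multiple of _BATCH_SIZE_ALIGNMENT (8). A standalone sketch of the same arithmetic with illustrative inputs:

def pad_batch_size(batch_size: int, alignment: int = 8) -> int:
    # Same rounding as VllmConfig.get_graph_batch_size above.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + alignment - 1) // alignment * alignment

assert pad_batch_size(2) == 2
assert pad_batch_size(3) == 4
assert pad_batch_size(9) == 16
assert pad_batch_size(257) == 264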
@@ -2496,6 +2541,28 @@ class VllmConfig:
             self.compilation_config.pass_config.enable_reshape = False
             self.compilation_config.level = CompilationLevel.PIECEWISE
 
+        if not envs.VLLM_USE_V1:
+            max_batchsize_to_capture = 0
+            if self.scheduler_config is not None and \
+                self.model_config is not None and \
+                not self.model_config.enforce_eager:
+                max_batchsize_to_capture = \
+                    self.get_max_graph_batch_size(
+                    self.scheduler_config.max_num_seqs)
+            batch_size_capture_list = [
+                size for size in _BATCH_SIZES_TO_CAPTURE
+                if size <= max_batchsize_to_capture
+            ]
+        else:
+            batch_size_capture_list = []
+            if self.model_config is not None and \
+                not self.model_config.enforce_eager:
+                batch_size_capture_list = [1, 2, 4
+                                           ] + [i for i in range(8, 513, 8)]
+
+        self.compilation_config.init_with_cudagraph_sizes(
+            batch_size_capture_list)
+
         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
             self.compilation_config.level != CompilationLevel.NO_COMPILATION:
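
In other words, __post_init__ now derives the candidate capture sizes itself (V0: every entry of _BATCH_SIZES_TO_CAPTURE up to the padded max_num_seqs; V1: the fixed [1, 2, 4] + range(8, 513, 8) list) and hands them to init_with_cudagraph_sizes, which uses them when cudagraph_capture_sizes is unset and sorts capture_sizes in descending order. An illustrative walk-through of the V0 branch with a hypothetical max_num_seqs of 100:

# Illustrative only; the real list is the _BATCH_SIZES_TO_CAPTURE added to
# vllm.config in the hunks above.
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]

max_num_seqs = 100                 # hypothetical scheduler setting
max_batchsize_to_capture = 104     # get_max_graph_batch_size(100) pads 100 -> 104
batch_size_capture_list = [
    size for size in _BATCH_SIZES_TO_CAPTURE
    if size <= max_batchsize_to_capture
]                                  # [1, 2, 4, 8, 16, ..., 96, 104]
# compilation_config.init_with_cudagraph_sizes(batch_size_capture_list)
# then stores these (when cudagraph_capture_sizes is unset) and sorts them
# in descending order.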
@@ -7,7 +7,7 @@ from transformers import JambaConfig
 
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -25,8 +25,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                      MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                      _get_graph_batch_size)
 
 from .interfaces import HasInnerState, SupportsLoRA
 from .utils import maybe_prefix
@@ -404,7 +402,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (_get_graph_batch_size(
+            max_batch_size = (VllmConfig.get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import MambaConfig
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -23,8 +23,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                      MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                      _get_graph_batch_size)
 
 from .utils import maybe_prefix
 
@@ -187,7 +185,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (_get_graph_batch_size(
+            max_batch_size = (VllmConfig.get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
             self.mamba_cache = MambaCacheManager(
@@ -8,7 +8,6 @@ import torch
 import torch.distributed
 import torch.nn as nn
 
-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
@@ -100,7 +99,11 @@ class GPUModelRunner:
             == CompilationLevel.PIECEWISE
             and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
-        self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
+        # The convention is different.
+        # self.cudagraph_batch_sizes sorts in ascending order.
+        # The batch sizes in the config are in descending order.
+        self.cudagraph_batch_sizes = list(
+            reversed(self.vllm_config.compilation_config.capture_sizes))
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
@@ -548,10 +551,9 @@ class GPUModelRunner:
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
-        with set_compile_context(self.cudagraph_batch_sizes):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                            dummy_kv_caches)
+        # Trigger compilation for general shape.
+        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
+                                        dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
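
One convention worth noting, echoing the comment added above: the config stores capture_sizes in descending order, while the V1 runner keeps its cudagraph_batch_sizes ascending, hence the reversed() call. A tiny illustration with made-up sizes:

# Illustrative sizes; in the runner they come from
# self.vllm_config.compilation_config.capture_sizes.
capture_sizes = [16, 8, 4, 2, 1]                  # config order: descending
cudagraph_batch_sizes = list(reversed(capture_sizes))
assert cudagraph_batch_sizes == [1, 2, 4, 8, 16]  # runner order: ascending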
@@ -25,8 +25,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUBuilder,
-                                      ModelInputForGPUWithSamplingMetadata,
-                                      _get_graph_batch_size)
+                                      ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict)
@@ -465,7 +464,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             # We will be using CUDA graph replay for this decode.
             max_len_of_block_table = self.get_max_block_per_batch()
             batch_size = len(encoder_seq_lens)
-            graph_batch_size = _get_graph_batch_size(batch_size)
+            graph_batch_size = self.vllm_config.get_graph_batch_size(
+                batch_size)
             assert graph_batch_size >= batch_size
             cuda_graph_pad_size = graph_batch_size - batch_size
             # extend the cross_block_tables and encoder_seq_lens to match
@@ -18,7 +18,6 @@ import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
@@ -63,16 +62,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 LORA_WARMUP_RANK = 8
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
 
 _NUM_WARMUP_ITERS = 2
 
 TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
@@ -763,7 +753,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                            max_decode_seq_len: int,
                            max_encoder_seq_len: int = 0) -> bool:
         return (decode_only and not self.runner.model_config.enforce_eager
-                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)
@@ -811,7 +800,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                                   max_encoder_seq_len):
             return -1
 
-        graph_batch_size = _get_graph_batch_size(batch_size)
+        graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
         assert graph_batch_size >= batch_size
         return graph_batch_size - batch_size
 
@@ -1023,7 +1012,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.max_batchsize_to_capture = _get_max_graph_batch_size(
+        self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
             self.scheduler_config.max_num_seqs)
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
@@ -1333,14 +1322,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                             dtype=self.model_config.dtype,
                                             device=self.device)
 
-        graph_batch_size = self.max_batchsize_to_capture
-        batch_size_capture_list = [
-            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
-        ]
-        if self.model_config.enforce_eager:
-            batch_size_capture_list = []
-        with set_compile_context(batch_size_capture_list):
-            self.execute_model(model_input, kv_caches, intermediate_tensors)
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         return
 
@@ -1459,18 +1441,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                            dtype=self.model_config.dtype,
                                            device=self.device)
 
-        graph_batch_size = self.max_batchsize_to_capture
-        batch_size_capture_list = [
-            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
-        ]
-
         with self.attn_state.graph_capture(
                 max_batch_size), graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
-                for batch_size in reversed(batch_size_capture_list):
+                for batch_size in \
+                    self.vllm_config.compilation_config.capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
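
The V0 capture path changes in the same way: instead of rebuilding a local batch_size_capture_list and reversing it, the loop reads compilation_config.capture_sizes directly. Because that list is kept in descending order, the largest batch is still captured first, which the NOTE above says helps reduce CUDA graph memory usage. A runnable sketch of just the loop ordering, with illustrative values:

# Illustrative values; the real loop also builds attention metadata and
# captures a CUDA graph per (virtual_engine, batch_size) pair.
capture_sizes = [16, 8, 4, 2, 1]   # descending, as stored in the config
pipeline_parallel_size = 2         # hypothetical
for virtual_engine in range(pipeline_parallel_size):
    for batch_size in capture_sizes:
        print(f"virtual_engine={virtual_engine}: capture batch_size={batch_size}")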
@@ -1993,37 +1971,3 @@ class CUDAGraphRunner(nn.Module):
             return self.output_buffers["hidden_states"]
 
         return self.output_buffers
-
-
-def _get_graph_batch_size(batch_size: int) -> int:
-    """Returns the padded batch size given actual batch size.
-
-    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
-    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
-    """
-    if batch_size <= 2:
-        return batch_size
-    elif batch_size <= 4:
-        return 4
-    else:
-        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
-                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-
-def _get_max_graph_batch_size(max_num_seqs: int) -> int:
-    """
-    max_num_seqs: Maximum number of sequences in a batch.
-    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-
-    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
-    which will deal with some edge cases like 1, 2, 4.
-
-    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
-    if not, it means the padded size is larger than the largest size in
-    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
-    """
-    padded_size = _get_graph_batch_size(max_num_seqs)
-    if padded_size in _BATCH_SIZES_TO_CAPTURE:
-        return padded_size
-    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
-    return _BATCH_SIZES_TO_CAPTURE[-1]