From 37c9babaa0628824595fdb8e7b13b998dee00f1d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 19 May 2025 21:38:16 +0000 Subject: [PATCH] enable naive microbatching Signed-off-by: Lucas Wilkinson --- examples/offline_inference/basic/basic.py | 24 +++++++-- vllm/config.py | 8 +++ vllm/engine/arg_utils.py | 5 ++ vllm/v1/attention/backends/flash_attn.py | 14 +++--- vllm/v1/worker/gpu_model_runner.py | 61 +++++++++++++++++------ vllm/worker/model_runner.py | 2 + 6 files changed, 89 insertions(+), 25 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb48346..bbab7d97ae11d 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +import logging +import os + from vllm import LLM, SamplingParams # Sample prompts. @@ -9,13 +12,28 @@ prompts = [ "The capital of France is", "The future of AI is", ] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +# Configure logging level for vllm (optional, uses VLLM_LOGGING_LEVEL env var). +logging_level = os.getenv("VLLM_LOGGING_LEVEL", "").upper() +if logging_level: + logging.basicConfig(level=getattr(logging, logging_level, logging.INFO)) + +# Create a sampling params object, optionally limiting output tokens via MAX_TOKENS env var. +param_kwargs = {"temperature": 0.8, "top_p": 0.95} +max_tokens_env = os.getenv("MAX_TOKENS") +if max_tokens_env is not None: + try: + param_kwargs["max_tokens"] = int(max_tokens_env) + except ValueError: + raise ValueError(f"Invalid MAX_TOKENS value: {max_tokens_env}") +sampling_params = SamplingParams(**param_kwargs) def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + llm = LLM(model="facebook/opt-125m", + enforce_eager=False, + compilation_config=2, + enable_microbatching=True,) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/vllm/config.py b/vllm/config.py index 5382e9a16829d..aaf419a61a2d4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1740,6 +1740,9 @@ class ParallelConfig: rank: int = 0 """Global rank in distributed setup.""" + + enable_microbatching: bool = False + """Enable microbatching for the model executor.""" @property def world_size_across_dp(self) -> int: @@ -4312,6 +4315,11 @@ class VllmConfig: "full_cuda_graph is not supported with " "cascade attention. Disabling cascade attention.") self.model_config.disable_cascade_attn = True + + if self.parallel_config.enable_microbatching: + # Microbatching is not supported with piecewise compilation yet. 
+            # More specifically, it does not work with piecewise CUDA graphs.
+            self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
 
         if self.model_config and self.model_config.use_mla and \
             not (current_platform.is_cuda() or current_platform.is_rocm()):
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f0c6b15b79da3..e1a9eb179e7ae 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -291,6 +291,7 @@ class EngineArgs:
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    enable_microbatching: bool = ParallelConfig.enable_microbatching
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
     block_size: Optional[BlockSize] = CacheConfig.block_size
@@ -621,6 +622,9 @@ class EngineArgs:
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument(
+            "--enable-microbatching",
+            **parallel_kwargs["enable_microbatching"])
         parallel_group.add_argument(
             "--max-parallel-loading-workers",
             **parallel_kwargs["max_parallel_loading_workers"])
@@ -1066,6 +1070,7 @@ class EngineArgs:
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
             enable_expert_parallel=self.enable_expert_parallel,
+            enable_microbatching=self.enable_microbatching,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             ray_workers_use_nsight=self.ray_workers_use_nsight,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 2c8636a5aa087..9d83e155864b9 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -324,10 +324,12 @@ class FlashAttentionMetadataBuilder:
                              scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build_slice(self, max_query_len: int, common_prefix_len: int,
+    def build_slice(self, req_slice: slice,
+                    token_slice: slice,
+                    max_query_len: int,
+                    common_prefix_len: int,
                     common_attn_metadata: CommonAttentionMetadata,
-                    req_slice: slice,
-                    token_slice: slice) -> FlashAttentionMetadata:
+                    ) -> FlashAttentionMetadata:
         num_reqs = req_slice.stop - req_slice.start
         num_tokens = token_slice.stop - token_slice.start
 
@@ -472,15 +474,15 @@ class FlashAttentionMetadataBuilder:
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
         return self.build_slice(
+            req_slice=slice(0, num_reqs),
+            token_slice=slice(0, num_actual_tokens),
             max_query_len=max_query_len,
             common_prefix_len=common_prefix_len,
             common_attn_metadata=common_attn_metadata,
-            req_slice=slice(0, num_reqs),
-            token_slice=slice(0, num_actual_tokens),
         )
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return use_cascade_attention(*args, **kwargs)
+        return False  # TODO: re-enable use_cascade_attention(*args, **kwargs)
 
 
 class FlashAttentionImpl(AttentionImpl):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 34abd898d02bc..927968120a09e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,6 +2,7 @@
 
 import copy
 import gc
+import os
 import time
 import weakref
 from typing import TYPE_CHECKING, Optional, TypeAlias, Union
@@ -73,6 +74,7 @@
 AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
 # list when ubatching is enabled
 PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
                                         AttnMetadataDict]
+UBatchSlices: TypeAlias = Optional[list[tuple[slice, slice]]]
 
 
 class GPUModelRunner(LoRAModelRunnerMixin):
@@ -493,13 +495,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if batch_changed or batch_reordered:
             self.input_batch.refresh_sampling_metadata()
 
-    def _prepare_inputs(
+    def _ubatch_split(
         self,
-        scheduler_output: "SchedulerOutput",
-        ubatch_slices: Optional[list[tuple[
-            slice, slice]]] = None,  # req_slice, token_slice
+        max_num_scheduled_tokens: int,
+        scheduler_output: "SchedulerOutput"
+    ) -> Optional[UBatchSlices]:
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_reqs = self.input_batch.num_reqs
+
+        if self.parallel_config.enable_microbatching and max_num_scheduled_tokens == 1:
+            # For pure decode we can just create ubatches by splitting the
+            # requests (and their scheduled tokens) in half
+            b0_reqs_end = num_reqs // 2
+            b0_tokens_end = total_num_scheduled_tokens // 2
+            return [
+                (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
+                (slice(b0_reqs_end, num_reqs),
+                 slice(b0_tokens_end, total_num_scheduled_tokens)),
+            ]
+        return None
+
+    def _prepare_inputs(
+        self, scheduler_output: "SchedulerOutput"
     ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
-               Optional[SpecDecodeMetadata]]:
+               Optional[SpecDecodeMetadata], Optional[UBatchSlices]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -515,6 +534,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         max_num_scheduled_tokens = max(tokens)
 
+        ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
+            max_num_scheduled_tokens, scheduler_output)
+
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs],
@@ -650,11 +672,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 attn_metadata_i = (
                     self.attn_metadata_builders[kv_cache_group_id].
                     build_slice(
-                        max(tokens[req_slice]),
-                        common_prefix_len=common_prefix_len,
-                        common_attn_metadata=common_attn_metadata,
                         req_slice=req_slice,
                         token_slice=token_slice,
+                        max_query_len=max(tokens[req_slice]),
+                        common_prefix_len=common_prefix_len,
+                        common_attn_metadata=common_attn_metadata,
                     ))
                 for layer_name in kv_cache_group_spec.layer_names:
                     assert type(attn_metadata) is list
@@ -699,7 +721,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return attn_metadata, logits_indices, spec_decode_metadata
+        return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1136,7 +1158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
             tokens_slice = \
-                slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
+                slice(tokens_slice.start, tokens_slice.start +
+                      self.vllm_config.pad_for_cudagraph(num_tokens))
         else:
             # Eager mode.
             # Pad tokens to multiple of tensor_parallel_size when
@@ -1145,8 +1168,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if self.vllm_config.compilation_config.pass_config. \
                 enable_sequence_parallelism and tp_size > 1:
                 from vllm.utils import round_up
-                tokens_slice = slice(tokens_slice.start,
-                                     round_up(num_tokens, tp_size))
+                tokens_slice = slice(
+                    tokens_slice.start,
+                    tokens_slice.start + round_up(num_tokens, tp_size))
+
+        # Update num_tokens to account for the padding.
         num_tokens = tokens_slice.stop - tokens_slice.start
 
         # _prepare_inputs may reorder the batch, so we must gather multi
@@ -1197,7 +1223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
-        ubatch_slices: Optional[list[tuple[slice, slice]]] = None
         self._update_states(scheduler_output)
 
         if not scheduler_output.total_num_scheduled_tokens:
@@ -1208,7 +1233,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return self.kv_connector_no_forward(scheduler_output)
 
         # Prepare the decoder inputs.
-        attn_metadata, logits_indices, spec_decode_metadata = (
+        attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices = (
             self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1217,6 +1242,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.maybe_setup_kv_connector(scheduler_output)
 
             if ubatch_slices is not None:
+                model_outputs = []
                 for i, (_, tokens_slice) in enumerate(ubatch_slices):
                     input_ids, positions, inputs_embeds, intermediate_tensors = \
                         self._get_model_inputs(tokens_slice, scheduler_output)
 
                     with set_forward_context(attn_metadata[i],
                                              self.vllm_config,
-                                             num_tokens=num_input_tokens):
-
+                                             num_tokens=num_input_tokens):
                         model_output = self.model(
                             input_ids=input_ids,
                             positions=positions,
                             intermediate_tensors=intermediate_tensors,
                             inputs_embeds=inputs_embeds,
                         )
+
+                    # Cloning is important for eventual piecewise CUDA graphs
+                    model_outputs.append(model_output.clone())
+                model_output = torch.cat(model_outputs, dim=0)
             else:
                 input_ids, positions, inputs_embeds, intermediate_tensors = \
                     self._get_model_inputs(slice(0, num_scheduled_tokens),
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 15f40bcef8969..6ecdfa6204f2a 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -2122,6 +2122,8 @@ class CUDAGraphRunner(nn.Module):
         **kwargs,
     ) -> torch.Tensor:
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+
+        print("=== CUDAGraphRunner forward ===")
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
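
Usage sketch (illustrative, not part of the patch itself): with the patch applied,
naive microbatching is toggled through the new enable_microbatching engine
argument, either offline as in the modified
examples/offline_inference/basic/basic.py, or on the command line via
--enable-microbatching. The snippet below reuses the example model from the
patch; when the flag is set, VllmConfig forces CompilationLevel.DYNAMO_ONCE as
shown above.

    from vllm import LLM, SamplingParams

    # enable_microbatching maps to the new ParallelConfig.enable_microbatching;
    # enforce_eager / compilation_config mirror the example script in this patch.
    llm = LLM(model="facebook/opt-125m",
              enforce_eager=False,
              compilation_config=2,
              enable_microbatching=True)

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.8, top_p=0.95))
    print(outputs[0].outputs[0].text)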