enable naive microbatching

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-19 21:38:16 +00:00
parent 8293182c8c
commit 37c9babaa0
6 changed files with 89 additions and 25 deletions

View File

@@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from vllm import LLM, SamplingParams
# Sample prompts.
@@ -9,13 +12,28 @@ prompts = [
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Configure logging level for vllm (optional, uses VLLM_LOGGING_LEVEL env var).
logging_level = os.getenv("VLLM_LOGGING_LEVEL", "").upper()
if logging_level:
logging.basicConfig(level=getattr(logging, logging_level, logging.INFO))
# Create a sampling params object, optionally limiting output tokens via MAX_TOKENS env var.
param_kwargs = {"temperature": 0.8, "top_p": 0.95}
max_tokens_env = os.getenv("MAX_TOKENS")
if max_tokens_env is not None:
try:
param_kwargs["max_tokens"] = int(max_tokens_env)
except ValueError:
raise ValueError(f"Invalid MAX_TOKENS value: {max_tokens_env}")
sampling_params = SamplingParams(**param_kwargs)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m",
enforce_eager=False,
compilation_config=2,
enable_microbatching=True,)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

View File

@@ -1740,6 +1740,9 @@ class ParallelConfig:
rank: int = 0
"""Global rank in distributed setup."""
enable_microbatching: bool = False
"""Enable microbatching for the model executor."""
@property
def world_size_across_dp(self) -> int:
@@ -4312,6 +4315,11 @@ class VllmConfig:
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
if self.parallel_config.enable_microbatching:
# Microbatching is not supported with piecewise compilation yet.
# More specifically piecewise cuda-graphs
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
if self.model_config and self.model_config.use_mla and \
not (current_platform.is_cuda() or current_platform.is_rocm()):

View File

@@ -291,6 +291,7 @@ class EngineArgs:
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_microbatching: bool = ParallelConfig.enable_microbatching
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[BlockSize] = CacheConfig.block_size
@@ -621,6 +622,9 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument(
"--enable-microbatching",
**parallel_kwargs["enable_microbatching"])
parallel_group.add_argument(
"--max-parallel-loading-workers",
**parallel_kwargs["max_parallel_loading_workers"])
@@ -1066,6 +1070,7 @@ class EngineArgs:
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
enable_expert_parallel=self.enable_expert_parallel,
enable_microbatching=self.enable_microbatching,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
ray_workers_use_nsight=self.ray_workers_use_nsight,

View File

@@ -324,10 +324,12 @@ class FlashAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool:
return False
def build_slice(self, max_query_len: int, common_prefix_len: int,
def build_slice(self, req_slice: slice,
token_slice: slice,
max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
req_slice: slice,
token_slice: slice) -> FlashAttentionMetadata:
) -> FlashAttentionMetadata:
num_reqs = req_slice.stop - req_slice.start
num_tokens = token_slice.stop - token_slice.start
@@ -472,15 +474,15 @@ class FlashAttentionMetadataBuilder:
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
return self.build_slice(
req_slice=slice(0, num_reqs),
token_slice=slice(0, num_actual_tokens),
max_query_len=max_query_len,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
req_slice=slice(0, num_reqs),
token_slice=slice(0, num_actual_tokens),
)
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)
return False #use_cascade_attention(*args, **kwargs)
class FlashAttentionImpl(AttentionImpl):

View File

@@ -2,6 +2,7 @@
import copy
import gc
import os
import time
import weakref
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
@@ -73,6 +74,7 @@ AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
AttnMetadataDict]
UBatchSlices: TypeAlias = Optional[list[tuple[slice, slice]]]
class GPUModelRunner(LoRAModelRunnerMixin):
@@ -493,13 +495,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
def _prepare_inputs(
def _ubatch_split(
self,
scheduler_output: "SchedulerOutput",
ubatch_slices: Optional[list[tuple[
slice, slice]]] = None, # req_slice, token_slice
max_num_scheduled_tokens: int,
scheduler_output: "SchedulerOutput"
) -> Optional[UBatchSlices]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_reqs = self.input_batch.num_reqs
if self.parallel_config.enable_microbatching and max_num_scheduled_tokens == 1:
# For pure decode we can just create ubatches by cutting the request
# in half
b0_reqs_end = num_reqs // 2
b0_tokens_end = total_num_scheduled_tokens // 2
return [
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
(slice(b0_reqs_end, num_reqs),
slice(b0_tokens_end, total_num_scheduled_tokens)),
]
return None
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
Optional[SpecDecodeMetadata]]:
Optional[SpecDecodeMetadata], Optional[UBatchSlices]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -515,6 +534,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
max_num_scheduled_tokens, scheduler_output)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
@@ -650,11 +672,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].
build_slice(
max(tokens[req_slice]),
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
req_slice=req_slice,
token_slice=token_slice,
max_query_len=max(tokens[req_slice]),
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
@@ -699,7 +721,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata
return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
def _compute_cascade_attn_prefix_len(
self,
@@ -1136,7 +1158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
tokens_slice = \
slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
slice(tokens_slice.start, tokens_slice.start+
self.vllm_config.pad_for_cudagraph(num_tokens))
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
@@ -1145,8 +1168,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up
tokens_slice = slice(tokens_slice.start,
round_up(num_tokens, tp_size))
tokens_slice = slice(
tokens_slice.start,
tokens_slice.start + round_up(num_tokens, tp_size))
# update num tokens for padding
num_tokens = tokens_slice.stop - tokens_slice.start
# _prepare_inputs may reorder the batch, so we must gather multi
@@ -1197,7 +1223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
ubatch_slices: Optional[list[tuple[slice, slice]]] = None
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
@@ -1208,7 +1233,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1217,6 +1242,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.maybe_setup_kv_connector(scheduler_output)
if ubatch_slices is not None:
model_outputs = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
input_ids, positions, inputs_embeds, intermediate_tensors = \
self._get_model_inputs(tokens_slice, scheduler_output)
@@ -1224,14 +1250,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(attn_metadata[i],
self.vllm_config,
num_tokens=num_input_tokens):
num_tokens=num_input_token):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
# clone is important for eventually piecewise cuda-graphs
model_outputs.append(model_output.clone())
model_output = torch.cat(model_outputs, dim=0)
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
self._get_model_inputs(slice(0, num_scheduled_tokens),

View File

@@ -2122,6 +2122,8 @@ class CUDAGraphRunner(nn.Module):
**kwargs,
) -> torch.Tensor:
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
print("=== CUDAGraphRunner forward ===")
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)