mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 20:37:05 +08:00
enable naive microbatching
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
8293182c8c
commit
37c9babaa0
@ -1,5 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
@ -9,13 +12,28 @@ prompts = [
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# Configure logging level for vllm (optional, uses VLLM_LOGGING_LEVEL env var).
|
||||
logging_level = os.getenv("VLLM_LOGGING_LEVEL", "").upper()
|
||||
if logging_level:
|
||||
logging.basicConfig(level=getattr(logging, logging_level, logging.INFO))
|
||||
|
||||
# Create a sampling params object, optionally limiting output tokens via MAX_TOKENS env var.
|
||||
param_kwargs = {"temperature": 0.8, "top_p": 0.95}
|
||||
max_tokens_env = os.getenv("MAX_TOKENS")
|
||||
if max_tokens_env is not None:
|
||||
try:
|
||||
param_kwargs["max_tokens"] = int(max_tokens_env)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid MAX_TOKENS value: {max_tokens_env}")
|
||||
sampling_params = SamplingParams(**param_kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=False,
|
||||
compilation_config=2,
|
||||
enable_microbatching=True,)
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
|
||||
@ -1740,6 +1740,9 @@ class ParallelConfig:
|
||||
|
||||
rank: int = 0
|
||||
"""Global rank in distributed setup."""
|
||||
|
||||
enable_microbatching: bool = False
|
||||
"""Enable microbatching for the model executor."""
|
||||
|
||||
@property
|
||||
def world_size_across_dp(self) -> int:
|
||||
@ -4312,6 +4315,11 @@ class VllmConfig:
|
||||
"full_cuda_graph is not supported with "
|
||||
"cascade attention. Disabling cascade attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
if self.parallel_config.enable_microbatching:
|
||||
# Microbatching is not supported with piecewise compilation yet.
|
||||
# More specifically piecewise cuda-graphs
|
||||
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
|
||||
if self.model_config and self.model_config.use_mla and \
|
||||
not (current_platform.is_cuda() or current_platform.is_rocm()):
|
||||
|
||||
@ -291,6 +291,7 @@ class EngineArgs:
|
||||
data_parallel_address: Optional[str] = None
|
||||
data_parallel_rpc_port: Optional[int] = None
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
enable_microbatching: bool = ParallelConfig.enable_microbatching
|
||||
max_parallel_loading_workers: Optional[
|
||||
int] = ParallelConfig.max_parallel_loading_workers
|
||||
block_size: Optional[BlockSize] = CacheConfig.block_size
|
||||
@ -621,6 +622,9 @@ class EngineArgs:
|
||||
parallel_group.add_argument(
|
||||
"--enable-expert-parallel",
|
||||
**parallel_kwargs["enable_expert_parallel"])
|
||||
parallel_group.add_argument(
|
||||
"--enable-microbatching",
|
||||
**parallel_kwargs["enable_microbatching"])
|
||||
parallel_group.add_argument(
|
||||
"--max-parallel-loading-workers",
|
||||
**parallel_kwargs["max_parallel_loading_workers"])
|
||||
@ -1066,6 +1070,7 @@ class EngineArgs:
|
||||
data_parallel_master_ip=data_parallel_address,
|
||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
enable_microbatching=self.enable_microbatching,
|
||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
||||
|
||||
@ -324,10 +324,12 @@ class FlashAttentionMetadataBuilder:
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build_slice(self, max_query_len: int, common_prefix_len: int,
|
||||
def build_slice(self, req_slice: slice,
|
||||
token_slice: slice,
|
||||
max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
req_slice: slice,
|
||||
token_slice: slice) -> FlashAttentionMetadata:
|
||||
) -> FlashAttentionMetadata:
|
||||
num_reqs = req_slice.stop - req_slice.start
|
||||
num_tokens = token_slice.stop - token_slice.start
|
||||
|
||||
@ -472,15 +474,15 @@ class FlashAttentionMetadataBuilder:
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
return self.build_slice(
|
||||
req_slice=slice(0, num_reqs),
|
||||
token_slice=slice(0, num_actual_tokens),
|
||||
max_query_len=max_query_len,
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
req_slice=slice(0, num_reqs),
|
||||
token_slice=slice(0, num_actual_tokens),
|
||||
)
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
return False #use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
||||
@ -73,6 +74,7 @@ AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
|
||||
# list when ubatching is enabled
|
||||
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
|
||||
AttnMetadataDict]
|
||||
UBatchSlices: TypeAlias = Optional[list[tuple[slice, slice]]]
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
@ -493,13 +495,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if batch_changed or batch_reordered:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _prepare_inputs(
|
||||
def _ubatch_split(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
ubatch_slices: Optional[list[tuple[
|
||||
slice, slice]]] = None, # req_slice, token_slice
|
||||
max_num_scheduled_tokens: int,
|
||||
scheduler_output: "SchedulerOutput"
|
||||
) -> Optional[UBatchSlices]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
|
||||
if self.parallel_config.enable_microbatching and max_num_scheduled_tokens == 1:
|
||||
# For pure decode we can just create ubatchs by cutting the request
|
||||
# in half
|
||||
b0_reqs_end = num_reqs // 2
|
||||
b0_tokens_end = total_num_scheduled_tokens // 2
|
||||
return [
|
||||
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
|
||||
(slice(b0_reqs_end, num_reqs),
|
||||
slice(b0_tokens_end, total_num_scheduled_tokens)),
|
||||
]
|
||||
return None
|
||||
|
||||
def _prepare_inputs(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
||||
Optional[SpecDecodeMetadata]]:
|
||||
Optional[SpecDecodeMetadata], Optional[UBatchSlices]]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@ -515,6 +534,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
|
||||
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
|
||||
max_num_scheduled_tokens, scheduler_output)
|
||||
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
@ -650,11 +672,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_metadata_i = (
|
||||
self.attn_metadata_builders[kv_cache_group_id].
|
||||
build_slice(
|
||||
max(tokens[req_slice]),
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
req_slice=req_slice,
|
||||
token_slice=token_slice,
|
||||
max_query_len=max(tokens[req_slice]),
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
))
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
@ -699,7 +721,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
return attn_metadata, logits_indices, spec_decode_metadata
|
||||
return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
|
||||
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
@ -1136,7 +1158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
tokens_slice = \
|
||||
slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
|
||||
slice(tokens_slice.start, tokens_slice.start+
|
||||
self.vllm_config.pad_for_cudagraph(num_tokens))
|
||||
else:
|
||||
# Eager mode.
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
@ -1145,8 +1168,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.vllm_config.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism and tp_size > 1:
|
||||
from vllm.utils import round_up
|
||||
tokens_slice = slice(tokens_slice.start,
|
||||
round_up(num_tokens, tp_size))
|
||||
tokens_slice = slice(
|
||||
tokens_slice.start,
|
||||
tokens_slice.start + round_up(num_tokens, tp_size))
|
||||
|
||||
# update num tokens for padding
|
||||
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
@ -1197,7 +1223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
||||
ubatch_slices: Optional[list[tuple[slice, slice]]] = None
|
||||
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
@ -1208,7 +1233,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||
attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices = (
|
||||
self._prepare_inputs(scheduler_output))
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
|
||||
@ -1217,6 +1242,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
model_outputs = []
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
@ -1224,14 +1250,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
with set_forward_context(attn_metadata[i],
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
|
||||
num_tokens=num_input_token):
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
# clone is important for eventually piecewise cuda-graphs
|
||||
model_outputs.append(model_output.clone())
|
||||
model_output = torch.cat(model_outputs, dim=0)
|
||||
else:
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
self._get_model_inputs(slice(0, num_scheduled_tokens),
|
||||
|
||||
@ -2122,6 +2122,8 @@ class CUDAGraphRunner(nn.Module):
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
print("=== CUDAGraphRunner forward ===")
|
||||
|
||||
# Copy the input tensors to the input buffers.
|
||||
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user