mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 02:09:07 +08:00
enable naive microbatching
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
8293182c8c
commit
37c9babaa0
@ -1,5 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
# Sample prompts.
|
# Sample prompts.
|
||||||
@ -9,13 +12,28 @@ prompts = [
|
|||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
]
|
]
|
||||||
# Create a sampling params object.
|
# Configure logging level for vllm (optional, uses VLLM_LOGGING_LEVEL env var).
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
logging_level = os.getenv("VLLM_LOGGING_LEVEL", "").upper()
|
||||||
|
if logging_level:
|
||||||
|
logging.basicConfig(level=getattr(logging, logging_level, logging.INFO))
|
||||||
|
|
||||||
|
# Create a sampling params object, optionally limiting output tokens via MAX_TOKENS env var.
|
||||||
|
param_kwargs = {"temperature": 0.8, "top_p": 0.95}
|
||||||
|
max_tokens_env = os.getenv("MAX_TOKENS")
|
||||||
|
if max_tokens_env is not None:
|
||||||
|
try:
|
||||||
|
param_kwargs["max_tokens"] = int(max_tokens_env)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Invalid MAX_TOKENS value: {max_tokens_env}")
|
||||||
|
sampling_params = SamplingParams(**param_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m")
|
llm = LLM(model="facebook/opt-125m",
|
||||||
|
enforce_eager=False,
|
||||||
|
compilation_config=2,
|
||||||
|
enable_microbatching=True,)
|
||||||
# Generate texts from the prompts.
|
# Generate texts from the prompts.
|
||||||
# The output is a list of RequestOutput objects
|
# The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
|
|||||||
@ -1741,6 +1741,9 @@ class ParallelConfig:
|
|||||||
rank: int = 0
|
rank: int = 0
|
||||||
"""Global rank in distributed setup."""
|
"""Global rank in distributed setup."""
|
||||||
|
|
||||||
|
enable_microbatching: bool = False
|
||||||
|
"""Enable microbatching for the model executor."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def world_size_across_dp(self) -> int:
|
def world_size_across_dp(self) -> int:
|
||||||
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||||
@ -4313,6 +4316,11 @@ class VllmConfig:
|
|||||||
"cascade attention. Disabling cascade attention.")
|
"cascade attention. Disabling cascade attention.")
|
||||||
self.model_config.disable_cascade_attn = True
|
self.model_config.disable_cascade_attn = True
|
||||||
|
|
||||||
|
if self.parallel_config.enable_microbatching:
|
||||||
|
# Microbatching is not supported with piecewise compilation yet.
|
||||||
|
# More specifically piecewise cuda-graphs
|
||||||
|
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||||
|
|
||||||
if self.model_config and self.model_config.use_mla and \
|
if self.model_config and self.model_config.use_mla and \
|
||||||
not (current_platform.is_cuda() or current_platform.is_rocm()):
|
not (current_platform.is_cuda() or current_platform.is_rocm()):
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@ -291,6 +291,7 @@ class EngineArgs:
|
|||||||
data_parallel_address: Optional[str] = None
|
data_parallel_address: Optional[str] = None
|
||||||
data_parallel_rpc_port: Optional[int] = None
|
data_parallel_rpc_port: Optional[int] = None
|
||||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||||
|
enable_microbatching: bool = ParallelConfig.enable_microbatching
|
||||||
max_parallel_loading_workers: Optional[
|
max_parallel_loading_workers: Optional[
|
||||||
int] = ParallelConfig.max_parallel_loading_workers
|
int] = ParallelConfig.max_parallel_loading_workers
|
||||||
block_size: Optional[BlockSize] = CacheConfig.block_size
|
block_size: Optional[BlockSize] = CacheConfig.block_size
|
||||||
@ -621,6 +622,9 @@ class EngineArgs:
|
|||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--enable-expert-parallel",
|
"--enable-expert-parallel",
|
||||||
**parallel_kwargs["enable_expert_parallel"])
|
**parallel_kwargs["enable_expert_parallel"])
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--enable-microbatching",
|
||||||
|
**parallel_kwargs["enable_microbatching"])
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--max-parallel-loading-workers",
|
"--max-parallel-loading-workers",
|
||||||
**parallel_kwargs["max_parallel_loading_workers"])
|
**parallel_kwargs["max_parallel_loading_workers"])
|
||||||
@ -1066,6 +1070,7 @@ class EngineArgs:
|
|||||||
data_parallel_master_ip=data_parallel_address,
|
data_parallel_master_ip=data_parallel_address,
|
||||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||||
enable_expert_parallel=self.enable_expert_parallel,
|
enable_expert_parallel=self.enable_expert_parallel,
|
||||||
|
enable_microbatching=self.enable_microbatching,
|
||||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||||
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
||||||
|
|||||||
@ -324,10 +324,12 @@ class FlashAttentionMetadataBuilder:
|
|||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def build_slice(self, max_query_len: int, common_prefix_len: int,
|
def build_slice(self, req_slice: slice,
|
||||||
|
token_slice: slice,
|
||||||
|
max_query_len: int,
|
||||||
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
req_slice: slice,
|
) -> FlashAttentionMetadata:
|
||||||
token_slice: slice) -> FlashAttentionMetadata:
|
|
||||||
num_reqs = req_slice.stop - req_slice.start
|
num_reqs = req_slice.stop - req_slice.start
|
||||||
num_tokens = token_slice.stop - token_slice.start
|
num_tokens = token_slice.stop - token_slice.start
|
||||||
|
|
||||||
@ -472,15 +474,15 @@ class FlashAttentionMetadataBuilder:
|
|||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata):
|
common_attn_metadata: CommonAttentionMetadata):
|
||||||
return self.build_slice(
|
return self.build_slice(
|
||||||
|
req_slice=slice(0, num_reqs),
|
||||||
|
token_slice=slice(0, num_actual_tokens),
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
req_slice=slice(0, num_reqs),
|
|
||||||
token_slice=slice(0, num_actual_tokens),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||||
return use_cascade_attention(*args, **kwargs)
|
return False #use_cascade_attention(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionImpl(AttentionImpl):
|
class FlashAttentionImpl(AttentionImpl):
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import gc
|
import gc
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
||||||
@ -73,6 +74,7 @@ AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
|
|||||||
# list when ubatching is enabled
|
# list when ubatching is enabled
|
||||||
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
|
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
|
||||||
AttnMetadataDict]
|
AttnMetadataDict]
|
||||||
|
UBatchSlices: TypeAlias = Optional[list[tuple[slice, slice]]]
|
||||||
|
|
||||||
|
|
||||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||||
@ -493,13 +495,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if batch_changed or batch_reordered:
|
if batch_changed or batch_reordered:
|
||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _ubatch_split(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
max_num_scheduled_tokens: int,
|
||||||
ubatch_slices: Optional[list[tuple[
|
scheduler_output: "SchedulerOutput"
|
||||||
slice, slice]]] = None, # req_slice, token_slice
|
) -> Optional[UBatchSlices]:
|
||||||
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
|
||||||
|
if self.parallel_config.enable_microbatching and max_num_scheduled_tokens == 1:
|
||||||
|
# For pure decode we can just create ubatchs by cutting the request
|
||||||
|
# in half
|
||||||
|
b0_reqs_end = num_reqs // 2
|
||||||
|
b0_tokens_end = total_num_scheduled_tokens // 2
|
||||||
|
return [
|
||||||
|
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
|
||||||
|
(slice(b0_reqs_end, num_reqs),
|
||||||
|
slice(b0_tokens_end, total_num_scheduled_tokens)),
|
||||||
|
]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _prepare_inputs(
|
||||||
|
self, scheduler_output: "SchedulerOutput"
|
||||||
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
||||||
Optional[SpecDecodeMetadata]]:
|
Optional[SpecDecodeMetadata], Optional[UBatchSlices]]:
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
@ -515,6 +534,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||||
max_num_scheduled_tokens = max(tokens)
|
max_num_scheduled_tokens = max(tokens)
|
||||||
|
|
||||||
|
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
|
||||||
|
max_num_scheduled_tokens, scheduler_output)
|
||||||
|
|
||||||
# Get request indices.
|
# Get request indices.
|
||||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||||
@ -650,11 +672,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
attn_metadata_i = (
|
attn_metadata_i = (
|
||||||
self.attn_metadata_builders[kv_cache_group_id].
|
self.attn_metadata_builders[kv_cache_group_id].
|
||||||
build_slice(
|
build_slice(
|
||||||
max(tokens[req_slice]),
|
|
||||||
common_prefix_len=common_prefix_len,
|
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
req_slice=req_slice,
|
req_slice=req_slice,
|
||||||
token_slice=token_slice,
|
token_slice=token_slice,
|
||||||
|
max_query_len=max(tokens[req_slice]),
|
||||||
|
common_prefix_len=common_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
))
|
))
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
@ -699,7 +721,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||||
|
|
||||||
return attn_metadata, logits_indices, spec_decode_metadata
|
return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
|
||||||
|
|
||||||
def _compute_cascade_attn_prefix_len(
|
def _compute_cascade_attn_prefix_len(
|
||||||
self,
|
self,
|
||||||
@ -1136,7 +1158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Use piecewise CUDA graphs.
|
# Use piecewise CUDA graphs.
|
||||||
# Add padding to the batch size.
|
# Add padding to the batch size.
|
||||||
tokens_slice = \
|
tokens_slice = \
|
||||||
slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
|
slice(tokens_slice.start, tokens_slice.start+
|
||||||
|
self.vllm_config.pad_for_cudagraph(num_tokens))
|
||||||
else:
|
else:
|
||||||
# Eager mode.
|
# Eager mode.
|
||||||
# Pad tokens to multiple of tensor_parallel_size when
|
# Pad tokens to multiple of tensor_parallel_size when
|
||||||
@ -1145,8 +1168,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.vllm_config.compilation_config.pass_config. \
|
if self.vllm_config.compilation_config.pass_config. \
|
||||||
enable_sequence_parallelism and tp_size > 1:
|
enable_sequence_parallelism and tp_size > 1:
|
||||||
from vllm.utils import round_up
|
from vllm.utils import round_up
|
||||||
tokens_slice = slice(tokens_slice.start,
|
tokens_slice = slice(
|
||||||
round_up(num_tokens, tp_size))
|
tokens_slice.start,
|
||||||
|
tokens_slice.start + round_up(num_tokens, tp_size))
|
||||||
|
|
||||||
|
# update num tokens for padding
|
||||||
num_tokens = tokens_slice.stop - tokens_slice.start
|
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
@ -1197,7 +1223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
||||||
ubatch_slices: Optional[list[tuple[slice, slice]]] = None
|
|
||||||
|
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
if not scheduler_output.total_num_scheduled_tokens:
|
if not scheduler_output.total_num_scheduled_tokens:
|
||||||
@ -1208,7 +1233,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return self.kv_connector_no_forward(scheduler_output)
|
return self.kv_connector_no_forward(scheduler_output)
|
||||||
|
|
||||||
# Prepare the decoder inputs.
|
# Prepare the decoder inputs.
|
||||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices = (
|
||||||
self._prepare_inputs(scheduler_output))
|
self._prepare_inputs(scheduler_output))
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
|
||||||
@ -1217,6 +1242,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
|
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
|
model_outputs = []
|
||||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
self._get_model_inputs(tokens_slice, scheduler_output)
|
self._get_model_inputs(tokens_slice, scheduler_output)
|
||||||
@ -1224,14 +1250,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
with set_forward_context(attn_metadata[i],
|
with set_forward_context(attn_metadata[i],
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=num_input_tokens):
|
num_tokens=num_input_token):
|
||||||
|
|
||||||
model_output = self.model(
|
model_output = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# clone is important for eventually piecewise cuda-graphs
|
||||||
|
model_outputs.append(model_output.clone())
|
||||||
|
model_output = torch.cat(model_outputs, dim=0)
|
||||||
else:
|
else:
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
self._get_model_inputs(slice(0, num_scheduled_tokens),
|
self._get_model_inputs(slice(0, num_scheduled_tokens),
|
||||||
|
|||||||
@ -2123,6 +2123,8 @@ class CUDAGraphRunner(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||||
|
|
||||||
|
print("=== CUDAGraphRunner forward ===")
|
||||||
|
|
||||||
# Copy the input tensors to the input buffers.
|
# Copy the input tensors to the input buffers.
|
||||||
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
||||||
if positions is not None:
|
if positions is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user