Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-08 17:13:49 +00:00
parent dc1b6af362
commit bfa828f399
5 changed files with 163 additions and 128 deletions

View File

@ -48,21 +48,22 @@ class DPMetadata:
return num_tokens_tensor return num_tokens_tensor
@staticmethod
def should_ubatch_across_dp(should_ubatch: bool, dp_size: int,
                            dp_rank: int) -> bool:
    """Decide, jointly with all data-parallel ranks, whether to micro-batch.

    Every rank records its own vote in slot ``dp_rank`` of an int32 CPU
    tensor of length ``dp_size`` and the tensors are combined with an
    all-reduce over the DP CPU process group.

    Args:
        should_ubatch: This rank's vote.
        dp_size: Number of data-parallel ranks (length of the vote tensor).
        dp_rank: Index of this rank inside the vote tensor.

    Returns:
        True only if every slot of the reduced tensor equals 1.
    """
    # One slot per DP rank; only this rank's slot is populated locally.
    votes = [0] * dp_size
    votes[dp_rank] = 1 if should_ubatch else 0
    should_ubatch_tensor = torch.tensor(votes,
                                        device="cpu",
                                        dtype=torch.int32)

    # Imported lazily, as in the original code.
    from vllm.distributed.parallel_state import get_dp_group
    dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)

    # This function uses the same ProcessGroup for all reduce as
    # num_tokens_across_dp. If there's an incorrect ordering of ARs
    # across DP ranks, this tensor can end up containing the number
    # of padded tokens for a DP rank.
    assert torch.all(
        (should_ubatch_tensor == 0) | (should_ubatch_tensor == 1)), (
            "should_ubatch all-reduce returned values other than 0/1")

    return bool(torch.all(should_ubatch_tensor == 1).item())
@ -183,7 +184,7 @@ def set_forward_context(
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
forward_context = create_forward_context(attn_metadata, vllm_config, forward_context = create_forward_context(attn_metadata, vllm_config,
virtual_engine, num_tokens, virtual_engine, num_tokens,
num_tokens_across_dp, num_tokens_across_dp,
skip_cuda_graphs) skip_cuda_graphs)

View File

@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.v1.worker.ubatching import get_current_ubatch_context
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm.v1.worker.ubatching import get_current_ubatch_context
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
@ -1567,19 +1567,26 @@ class FusedMoE(torch.nn.Module):
chunk_size = chunk_end - chunk_start chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :] hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :]
ubatch_ctx = get_current_ubatch_context() ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :] assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
batched_hidden_states = self.batched_hidden_states[
batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[
batch_buffer_idx, :]
assert (batched_hidden_states.size(0) # type: ignore assert (batched_hidden_states.size(0) # type: ignore
>= chunk_size) >= chunk_size)
assert (batched_router_logits.size(0) # type: ignore assert (batched_router_logits.size(0) # type: ignore
>= chunk_size) >= chunk_size)
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore staged_hidden_states = batched_hidden_states[:
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)

View File

@ -27,7 +27,7 @@ from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches) make_local_attention_virtual_batches, slice_query_start_locs)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable

View File

@ -60,6 +60,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def build_slice(
    self,
    req_slice: slice,
    token_slice: slice,
    max_query_len: int,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
) -> M:
    """
    Build attention metadata for one slice of the batch.

    Should only be called on builders that support attention slicing
    for micro batching. This base implementation supports nothing and
    always raises.

    Args:
        req_slice: presumably the range of requests in the micro-batch
            (TODO confirm against callers).
        token_slice: presumably the range of tokens in the micro-batch
            (TODO confirm against callers).
        max_query_len: forwarded to the concrete builder.
        common_prefix_len: forwarded to the concrete builder.
        common_attn_metadata: forwarded to the concrete builder.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    raise NotImplementedError
def can_run_in_cudagraph( def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool: self, common_attn_metadata: CommonAttentionMetadata) -> bool:
""" """
@ -105,6 +119,7 @@ def slice_query_start_locs(
return query_start_loc[req_slice.start: req_slice.stop + 1] -\ return query_start_loc[req_slice.start: req_slice.stop + 1] -\
query_start_loc[req_slice.start] query_start_loc[req_slice.start]
def validate_kv_sharing_target(current_layer_name, target_layer_name, def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context): static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} " error_msg = (f"Specified KV sharing target layer for {current_layer_name} "

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
import dataclasses
import gc import gc
import threading import threading
import time import time
@ -29,8 +30,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model) prepare_communication_buffer_for_model)
from vllm.forward_context import (create_forward_context, get_forward_context, from vllm.forward_context import (DPMetadata, create_forward_context,
override_forward_context, DPMetadata, get_forward_context,
override_forward_context,
set_forward_context) set_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
@ -48,8 +50,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up) is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@ -75,7 +77,6 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
import dataclasses
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
@ -99,6 +100,7 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
UbatchSlice: TypeAlias = tuple[slice, slice] UbatchSlice: TypeAlias = tuple[slice, slice]
UBatchSlices: TypeAlias = list[UbatchSlice] UBatchSlices: TypeAlias = list[UbatchSlice]
@dataclasses.dataclass @dataclasses.dataclass
class UbatchMetadata: class UbatchMetadata:
context: UBatchContext context: UBatchContext
@ -577,10 +579,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata() self.input_batch.refresh_sampling_metadata()
def _ubatch_split( def _ubatch_split(
self, self, max_num_scheduled_tokens: int,
max_num_scheduled_tokens: int, scheduler_output: "SchedulerOutput"
scheduler_output: "SchedulerOutput" ) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
# Don't bother with the should_ubatch handshaking unless microbatching # Don't bother with the should_ubatch handshaking unless microbatching
# is enabled # is enabled
if not self.parallel_config.enable_microbatching: if not self.parallel_config.enable_microbatching:
@ -607,27 +608,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
b0_tokens_end < total_num_scheduled_tokens b0_tokens_end < total_num_scheduled_tokens
ubatch_slices = [ ubatch_slices = [
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)), (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
(slice(b0_reqs_end, num_reqs), (slice(b0_reqs_end,
slice(b0_tokens_end, total_num_scheduled_tokens)), num_reqs), slice(b0_tokens_end,
total_num_scheduled_tokens)),
] ]
# Compute ubatch padding. This currently only accounts for DP padding # Compute ubatch padding. This currently only accounts for DP padding
num_pad_tokens = 0 num_pad_tokens = 0
num_tokens_after_padding = None num_tokens_after_padding = None
ubatch_abort = False ubatch_abort = False
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices) num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(
ubatch_slices)
if num_pad_tokens > 0: if num_pad_tokens > 0:
# Check if the padding would result in an empty second ubatch. # Check if the padding would result in an empty second ubatch.
# If so abort ubatching # If so abort ubatching
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens: if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens) self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
else: else:
ubatch_abort = True ubatch_abort = True
# Note that if we are attempting to ubatch by this point then we know that no # Note that if we are attempting to ubatch by this point then we know
# DP ranks are doing dummy runs. Meaning, we don't need a second call to # that no DP ranks are doing dummy runs. Meaning, we don't need a
# should_ubatch in _dummy_run # second call to should_ubatch in _dummy_run
should_ubatch = self.should_ubatch(False if ubatch_abort else True) should_ubatch = self.should_ubatch(not ubatch_abort)
if not should_ubatch: if not should_ubatch:
return (None, 0, None) return (None, 0, None)
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding) return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
@ -653,12 +656,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return cu_num_tokens, arange return cu_num_tokens, arange
def _prepare_inputs( def _prepare_inputs(
self, self, scheduler_output: "SchedulerOutput"
scheduler_output: "SchedulerOutput" ) -> tuple[PerLayerAttnMetadata, bool, torch.Tensor,
) -> tuple[dict[str, Any], bool, torch.Tensor,
Optional[SpecDecodeMetadata], np.ndarray, Optional[SpecDecodeMetadata], np.ndarray,
Optional[UBatchSlices], Optional[UBatchSlices], int, Optional[torch.Tensor]]:
int, Optional[torch.Tensor]]:
""" """
:return: tuple[ :return: tuple[
attn_metadata: layer-to-attention_metadata mapping, attn_metadata: layer-to-attention_metadata mapping,
@ -873,9 +874,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.lora_config: if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens) self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices, return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens, ubatch_slices, num_pad_tokens, spec_decode_metadata, num_scheduled_tokens, ubatch_slices,
num_tokens_after_padding) num_pad_tokens, num_tokens_after_padding)
def _compute_cascade_attn_prefix_len( def _compute_cascade_attn_prefix_len(
self, self,
@ -1343,8 +1344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32) dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def get_padding(self, def get_padding(
num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]: self,
num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]:
num_tokens_padded = num_tokens_unpadded num_tokens_padded = num_tokens_unpadded
@ -1352,7 +1354,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs. # Use piecewise CUDA graphs.
# Add padding to the batch size. # Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) num_tokens_padded = self.vllm_config.pad_for_cudagraph(
num_tokens_unpadded)
else: else:
# Eager mode. # Eager mode.
# Pad tokens to multiple of tensor_parallel_size when # Pad tokens to multiple of tensor_parallel_size when
@ -1364,12 +1367,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_padded = round_up(num_tokens_unpadded, tp_size) num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
num_pad_tokens = num_tokens_padded - num_tokens_unpadded num_pad_tokens = num_tokens_padded - num_tokens_unpadded
num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_padded) num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
num_tokens_padded)
return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding
def get_dp_padding_ubatch(self, def get_dp_padding_ubatch(
ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]: self,
ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size dp_size = self.vllm_config.parallel_config.data_parallel_size
if dp_size == 1: if dp_size == 1:
@ -1379,54 +1384,63 @@ class GPUModelRunner(LoRAModelRunnerMixin):
first_ubatch_slice = ubatch_slices[0] first_ubatch_slice = ubatch_slices[0]
second_ubatch_slice = ubatch_slices[1] second_ubatch_slice = ubatch_slices[1]
first_ubatch_num_tokens = first_ubatch_slice[1].stop - first_ubatch_slice[1].start first_ubatch_num_tokens = first_ubatch_slice[
second_ubatch_num_tokens = second_ubatch_slice[1].stop - second_ubatch_slice[1].start 1].stop - first_ubatch_slice[1].start
# We don't support prefills yet so the two ubatches should only differ second_ubatch_num_tokens = second_ubatch_slice[
1].stop - second_ubatch_slice[1].start
# We don't support prefills yet so the two ubatches should only differ
# by at most one token # by at most one token
assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1 assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
from vllm.utils import round_up num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
num_tokens_padded = round_up(num_tokens_unpadded, 2) num_tokens_padded = round_up(num_tokens_unpadded, 2)
num_tokens_per_ubatch = num_tokens_padded // 2 num_tokens_per_ubatch = num_tokens_padded // 2
# Note that we compute the number of padded tokens per ubatch # Note that we compute the number of padded tokens per ubatch
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_per_ubatch) num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
num_tokens_per_ubatch)
num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \ num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \
num_tokens_unpadded num_tokens_unpadded
return num_pad_tokens, num_tokens_after_padding return num_pad_tokens, num_tokens_after_padding
def pad_out_ubatch_first_stage(self, ubatch_slices: "UBatchSlices",
                               num_pad_tokens: int):
    """Move the ubatch split point to where it belongs after padding.

    This doesn't actually pad the ubatch slices. It just shifts the
    split point to the correct value so that padding can be applied
    to the second ubatch later. Should be called after ubatch
    slicing but before attention metadata creation.

    ``ubatch_slices`` is modified in place: both entries are replaced,
    and in each entry the request slice and the token slice are set to
    the same slice object.
    NOTE(review): reusing one slice for requests and tokens presumably
    relies on one token per request (decode only) - confirm.

    Args:
        ubatch_slices: Two ``(request_slice, token_slice)`` pairs.
        num_pad_tokens: Total number of padding tokens that will be added;
            must be smaller than the current number of tokens.
    """
    # End of the second ubatch's token slice == total unpadded tokens.
    original_num_tokens = ubatch_slices[1][1].stop
    assert num_pad_tokens < original_num_tokens, (
        "padding must be smaller than the number of real tokens")

    # Size of each ubatch once padding has been accounted for.
    total_num_tokens_per_ubatch = (original_num_tokens +
                                   num_pad_tokens) // 2

    padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
    # The second slice still ends at the unpadded total; the second stage
    # extends it to cover the padding.
    padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                       original_num_tokens)

    ubatch_slices[0] = (padded_first_ubatch_slice,
                        padded_first_ubatch_slice)
    ubatch_slices[1] = (padded_second_ubatch_slice,
                        padded_second_ubatch_slice)
def pad_out_ubatch_second_stage(self, ubatch_slices: "UBatchSlices",
                                num_total_tokens: int):
    """Extend the second ubatch so it covers the padding tokens.

    This is where the second ubatch is adjusted to account for the
    padding. Should be called after attention metadata creation. This
    just pads the second ubatch slice out to the total number of tokens
    (num_tokens + padding).

    ``ubatch_slices[1]`` is replaced in place; the first ubatch is left
    untouched.

    Args:
        ubatch_slices: Two ``(request_slice, token_slice)`` pairs.
        num_total_tokens: New end of the second ubatch's slice.
    """
    # TODO Add asserts to make sure stage one ran
    padded_second_ubatch_slice = slice(ubatch_slices[1][1].start,
                                       num_total_tokens)
    ubatch_slices[1] = (padded_second_ubatch_slice,
                        padded_second_ubatch_slice)
def should_ubatch(self, should_ubatch: bool) -> bool:
    """Combine this rank's micro-batching vote with the other DP ranks.

    Reads the data-parallel size and rank from ``self.vllm_config`` and
    delegates to ``DPMetadata.should_ubatch_across_dp``.

    Args:
        should_ubatch: Whether this rank wants to micro-batch.

    Returns:
        The value returned by ``DPMetadata.should_ubatch_across_dp``.
    """
    parallel_config = self.vllm_config.parallel_config
    return DPMetadata.should_ubatch_across_dp(
        should_ubatch,
        parallel_config.data_parallel_size,
        parallel_config.data_parallel_rank,
    )
def _get_dummy_model_inputs(self, num_tokens: int) -> tuple: def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
# Dummy batch. (hopefully we are the last one so we can just # Dummy batch. (hopefully we are the last one so we can just
@ -1455,8 +1469,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device)) device=self.device))
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
slice(0, num_tokens), None, False) slice(0, num_tokens), None, False)
return input_ids, positions, inputs_embeds, intermediate_tensors return input_ids, positions, inputs_embeds, intermediate_tensors
@ -1506,58 +1519,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tokens_slice, intermediate_tensors, True) tokens_slice, intermediate_tensors, True)
return input_ids, positions, inputs_embeds, intermediate_tensors return input_ids, positions, inputs_embeds, intermediate_tensors
def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
                 scheduler_output: Optional["SchedulerOutput"]) -> tuple:
    """Return the model inputs for ``tokens_slice``.

    Args:
        tokens_slice: Token range to build inputs for.
        use_dummy_input: When true, dummy inputs sized
            ``tokens_slice.stop - tokens_slice.start`` are produced and
            ``scheduler_output`` is ignored.
        scheduler_output: Required (non-None) when real inputs are requested.

    Returns:
        Whatever ``_get_dummy_model_inputs`` / ``_get_model_inputs`` return.
    """
    if use_dummy_input:
        # Dummy inputs only depend on how many tokens the slice spans.
        return self._get_dummy_model_inputs(tokens_slice.stop -
                                            tokens_slice.start)

    assert scheduler_output is not None
    return self._get_model_inputs(tokens_slice, scheduler_output)
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
                          compute_stream, is_dummy_run,
                          num_tokens_across_dp, skip_cuda_graphs,
                          scheduler_output) -> list[UbatchMetadata]:
    """Build one ``UbatchMetadata`` per entry of ``ubatch_slices``.

    The work happens in three passes, in this order:
      1. create a forward context for every ubatch,
      2. create the ubatch contexts from those forward contexts,
      3. fetch the model inputs for every ubatch and bundle them with the
         matching ubatch context.

    Args:
        ubatch_slices: Sequence of ``(request_slice, token_slice)`` pairs;
            only the token slice is used here.
        attn_metadata: Indexable per-ubatch attention metadata, or None.
        compute_stream: Forwarded to ``make_ubatch_contexts``.
        is_dummy_run: Forwarded to ``model_inputs`` as ``use_dummy_input``.
        num_tokens_across_dp: Forwarded to ``create_forward_context``.
        skip_cuda_graphs: Forwarded to ``create_forward_context``.
        scheduler_output: Forwarded to ``model_inputs``.

    Returns:
        One ``UbatchMetadata`` per ubatch, in the order of ``ubatch_slices``.
    """
    # Create one forward context per ubatch
    forward_contexts = []
    for i, (_, tokens_slice) in enumerate(ubatch_slices):
        num_tokens = tokens_slice.stop - tokens_slice.start
        ubatch_attn_metadata = (attn_metadata[i]
                                if attn_metadata is not None else None)
        forward_contexts.append(
            create_forward_context(
                ubatch_attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                skip_cuda_graphs=skip_cuda_graphs))

    ubatch_ctxs = make_ubatch_contexts(
        num_micro_batches=len(ubatch_slices),
        compute_stream=compute_stream,
        forward_contexts=forward_contexts,
        device=self.device)

    # Inputs are gathered only after the ubatch contexts exist, matching
    # the original call order.
    ubatch_metadata: list[UbatchMetadata] = []
    for ctx, (_, tokens_slice) in zip(ubatch_ctxs, ubatch_slices):
        input_ids, positions, inputs_embeds, intermediate_tensors = \
            self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
        ubatch_metadata.append(
            UbatchMetadata(context=ctx,
                           input_ids=input_ids,
                           positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors))

    return ubatch_metadata
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor: def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
@torch.inference_mode() @torch.inference_mode()
def _ubatch_thread(results, model, ubatch_metadata): def _ubatch_thread(results, model, ubatch_metadata):
with ubatch_metadata.context: with ubatch_metadata.context:
@ -1578,11 +1586,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
ubatch_threads = [] ubatch_threads = []
for metadata in ubatch_metadata: for metadata in ubatch_metadata:
thread = threading.Thread(target=_ubatch_thread, thread = threading.Thread(target=_ubatch_thread,
args=( args=(
results, results,
model, model,
metadata, metadata,
)) ))
ubatch_threads.append(thread) ubatch_threads.append(thread)
thread.start() thread.start()
@ -1602,37 +1610,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp: Optional[torch.Tensor] = None, num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False): skip_cuda_graphs: bool = False):
# run micro-batched # run micro-batched
if ubatch_slices is not None: if ubatch_slices is not None:
assert len(ubatch_slices) == 2, "Only two ubatches has been tested" assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
compute_stream = torch.cuda.current_stream() compute_stream = torch.cuda.current_stream()
ubatch_metadata = self._make_ubatch_metadata( ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
compute_stream=compute_stream, compute_stream=compute_stream,
is_dummy_run=is_dummy_run, is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs, skip_cuda_graphs=skip_cuda_graphs,
scheduler_output=scheduler_output scheduler_output=scheduler_output)
)
return self._run_ubatches(ubatch_metadata, self.model) return self._run_ubatches(ubatch_metadata, self.model)
# run normal batch # run normal batch
else: else:
input_ids, positions, inputs_embeds, intermediate_tensors = \ input_ids, positions, inputs_embeds, intermediate_tensors = \
self.model_inputs(slice(0, num_scheduled_tokens), is_dummy_run, scheduler_output) self.model_inputs(slice(0, num_scheduled_tokens),
is_dummy_run,
scheduler_output)
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1, num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs): skip_cuda_graphs=skip_cuda_graphs):
return self.model( return self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
def _pool( def _pool(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1693,18 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output) return self.kv_connector_no_forward(scheduler_output)
# num_scheduled_tokens_old = scheduler_output.total_num_scheduled_tokens
# num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_scheduled_tokens_old)
# Prepare the decoder inputs. # Prepare the decoder inputs.
attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices, num_pad_tokens, num_tokens_after_padding = ( (attn_metadata, attention_cuda_graphs, logits_indices,
self._prepare_inputs(scheduler_output)) spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices,
num_pad_tokens,
num_tokens_after_padding) = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
if ubatch_slices and num_pad_tokens > 0: if ubatch_slices and num_pad_tokens > 0:
num_input_tokens += num_pad_tokens num_input_tokens += num_pad_tokens
self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens) self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
elif ubatch_slices is None: elif ubatch_slices is None:
num_pad, num_tokens_after_padding = self.get_padding(num_input_tokens) num_pad, num_tokens_after_padding = self.get_padding(
num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
# Some attention backends only support CUDA Graphs in pure decode. # Some attention backends only support CUDA Graphs in pure decode.
@ -1856,6 +1867,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
else: else:
assert not ubatch_slices
assert isinstance(attn_metadata, dict)
spec_token_ids = self.propose_draft_token_ids( spec_token_ids = self.propose_draft_token_ids(
scheduler_output, scheduler_output,
valid_sampled_token_ids, valid_sampled_token_ids,
@ -2301,7 +2314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
is_profile: bool = False, is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# _dummy_run doesn't go through _prepare_inputs so # _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP groups that may be # we synchronize with other DP groups that may be
# attempting to microbatch here. # attempting to microbatch here.
if self.parallel_config.enable_microbatching: if self.parallel_config.enable_microbatching:
@ -2323,8 +2336,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list, num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32) dtype=np.int32)
# We currently only microbatch if the number of tokens is # We currently only microbatch if the number of tokens is
# over a certain threshold. # over a certain threshold.
attn_metadata: Optional[dict[str, Any]] = None attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph: if capture_attn_cudagraph:
attn_metadata = {} attn_metadata = {}
@ -2354,7 +2367,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
outputs = self._run_model( outputs = self._run_model(