commit bfa828f399 (parent dc1b6af362)
Author: Sage Moore <sage@neuralmagic.com>
Date: 2025-07-08 17:13:49 +00:00
Signed-off-by: Sage Moore <sage@neuralmagic.com>
5 changed files with 163 additions and 128 deletions

View File

@@ -48,21 +48,22 @@ class DPMetadata:
return num_tokens_tensor
@staticmethod
def should_ubatch_across_dp(should_ubatch: bool, dp_size: int,
dp_rank: int) -> bool:
should_ubatch_across_dp = [0] * dp_size
should_ubatch_across_dp[dp_rank] = 1 if should_ubatch else 0
should_ubatch_tensor = torch.tensor(should_ubatch_across_dp,
device="cpu",
dtype=torch.int32)
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
# This function uses the same ProcessGroup for all reduce as
# num_tokens_across_dp. If there's an incorrect ordering of ARs
# across DP ranks, this tensor can end up containing the number
# of padded tokens for a DP rank.
assert torch.all((should_ubatch_tensor == 0)
| (should_ubatch_tensor == 1))
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
return result
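For readers following the DP handshake above, here is a minimal single-process sketch of the consensus that should_ubatch_across_dp implements. The all_reduce over the DP CPU group is replaced by an explicit sum of per-rank one-hot contributions; the helper name and the checks at the bottom are illustrative only.

import torch

def should_ubatch_consensus(per_rank_flags: list[bool]) -> bool:
    dp_size = len(per_rank_flags)
    # Stand-in for dist.all_reduce(sum) over the DP cpu_group: add up each
    # rank's one-hot contribution into a single dp_size-long vector.
    reduced = torch.zeros(dp_size, dtype=torch.int32)
    for dp_rank, should_ubatch in enumerate(per_rank_flags):
        contribution = torch.zeros(dp_size, dtype=torch.int32)
        contribution[dp_rank] = 1 if should_ubatch else 0
        reduced += contribution
    # Same sanity check as the patched method: every slot must be 0 or 1.
    assert torch.all((reduced == 0) | (reduced == 1))
    # Ubatching only proceeds if every DP rank voted yes.
    return bool(torch.all(reduced == 1).item())

assert should_ubatch_consensus([True, True]) is True
assert should_ubatch_consensus([True, False]) is False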
@@ -183,7 +184,7 @@ def set_forward_context(
forward_start_time = time.perf_counter()
forward_context = create_forward_context(attn_metadata, vllm_config,
virtual_engine, num_tokens,
num_tokens_across_dp,
skip_cuda_graphs)

View File

@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm.v1.worker.ubatching import get_current_ubatch_context
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@@ -1567,19 +1567,26 @@ class FusedMoE(torch.nn.Module):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
batched_hidden_states = self.batched_hidden_states[
batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[
batch_buffer_idx, :]
assert (batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
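A short sketch of the buffer selection introduced above: each micro-batch stages its chunk into its own slot of a preallocated buffer, with ubatch id -1 (no ubatch context) falling back to slot 0. The tensor shapes and the stage_chunk helper are made up for illustration; the indexing mirrors the batch_buffer_idx logic in the diff.

import torch

# Illustrative shapes: 2 micro-batch slots, room for 8 tokens, hidden size 4.
num_ubatches, max_tokens, hidden_size = 2, 8, 4
batched_hidden_states = torch.zeros(num_ubatches, max_tokens, hidden_size)

def stage_chunk(chunk: torch.Tensor, ubatch_id: int) -> torch.Tensor:
    # Mirror of batch_buffer_idx: no ubatch context (-1) uses slot 0.
    buffer_idx = 0 if ubatch_id == -1 else ubatch_id
    staged = batched_hidden_states[buffer_idx, :chunk.size(0), :]
    staged.copy_(chunk, non_blocking=True)
    return staged

chunk = torch.randn(5, hidden_size)
assert torch.equal(stage_chunk(chunk, ubatch_id=1), chunk)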

View File

@@ -27,7 +27,7 @@ from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches, slice_query_start_locs)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

View File

@@ -60,6 +60,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def build_slice(
self,
req_slice: slice,
token_slice: slice,
max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
) -> M:
"""
Should only be called on builders that support attention slicing
for micro batching
"""
raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
@@ -105,6 +119,7 @@ def slice_query_start_locs(
return query_start_loc[req_slice.start: req_slice.stop + 1] -\
query_start_loc[req_slice.start]
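To make the helper above concrete, a small sketch of what slice_query_start_locs computes: the cumulative query_start_loc entries for a contiguous request slice, re-based so the sliced micro-batch starts at token offset 0. The demo values are illustrative.

import torch

def slice_query_start_locs_demo(query_start_loc: torch.Tensor,
                                req_slice: slice) -> torch.Tensor:
    # Same expression as the helper above: take the slice and re-base it.
    return (query_start_loc[req_slice.start:req_slice.stop + 1] -
            query_start_loc[req_slice.start])

# 4 requests with 3, 2, 4 and 1 tokens -> cumulative [0, 3, 5, 9, 10].
query_start_loc = torch.tensor([0, 3, 5, 9, 10])
# Requests 2..3 (a second micro-batch) see re-based offsets [0, 4, 5].
assert torch.equal(slice_query_start_locs_demo(query_start_loc, slice(2, 4)),
                   torch.tensor([0, 4, 5]))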
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses
import gc
import threading
import time
@@ -29,8 +30,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, create_forward_context,
get_forward_context,
override_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
@@ -48,8 +50,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -75,7 +77,6 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
import dataclasses
if TYPE_CHECKING:
import xgrammar as xgr
@@ -99,6 +100,7 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
UbatchSlice: TypeAlias = tuple[slice, slice]
UBatchSlices: TypeAlias = list[UbatchSlice]
@dataclasses.dataclass
class UbatchMetadata:
context: UBatchContext
@@ -577,10 +579,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata()
def _ubatch_split(
self, max_num_scheduled_tokens: int,
scheduler_output: "SchedulerOutput"
) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
# Don't bother with the should_ubatch handshaking unless microbatching
# is enabled
if not self.parallel_config.enable_microbatching:
@@ -607,27 +608,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
b0_tokens_end < total_num_scheduled_tokens
ubatch_slices = [
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
(slice(b0_reqs_end,
num_reqs), slice(b0_tokens_end,
total_num_scheduled_tokens)),
]
# Compute ubatch padding. This currently only accounts for DP padding
num_pad_tokens = 0
num_tokens_after_padding = None
ubatch_abort = False
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(
ubatch_slices)
if num_pad_tokens > 0:
# Check if the padding would result in an empty second ubatch.
# If so abort ubatching
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
else:
ubatch_abort = True
# Note that if we are attempting to ubatch by this point then we know
# that no DP ranks are doing dummy runs. Meaning, we don't need a
# second call to should_ubatch in _dummy_run
should_ubatch = self.should_ubatch(not ubatch_abort)
if not should_ubatch:
return (None, 0, None)
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
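A sketch of the slice layout that _ubatch_split returns. The split-point computation is not shown in this hunk; the helper below assumes a pure-decode batch (one token per request) split at the halfway mark, which is the simplest case the surrounding code handles, and split_decode_batch is a hypothetical name.

UbatchSlice = tuple[slice, slice]

def split_decode_batch(num_reqs: int,
                       total_num_scheduled_tokens: int) -> list[UbatchSlice]:
    # Pure decode: one token per request, so the request split and the token
    # split fall at the same halfway point.
    b0_reqs_end = num_reqs // 2
    b0_tokens_end = total_num_scheduled_tokens // 2
    return [
        (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
        (slice(b0_reqs_end, num_reqs),
         slice(b0_tokens_end, total_num_scheduled_tokens)),
    ]

# 8 decode requests split into two ubatches of 4 requests / 4 tokens each.
assert split_decode_batch(8, 8) == [(slice(0, 4), slice(0, 4)),
                                    (slice(4, 8), slice(4, 8))]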
@@ -653,12 +656,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return cu_num_tokens, arange
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
) -> tuple[PerLayerAttnMetadata, bool, torch.Tensor,
Optional[SpecDecodeMetadata], np.ndarray,
Optional[UBatchSlices], int, Optional[torch.Tensor]]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
@@ -873,9 +874,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens, ubatch_slices,
num_pad_tokens, num_tokens_after_padding)
def _compute_cascade_attn_prefix_len(
self,
@@ -1343,8 +1344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def get_padding(
self,
num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]:
num_tokens_padded = num_tokens_unpadded
@@ -1352,7 +1354,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(
num_tokens_unpadded)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
@@ -1364,12 +1367,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
num_tokens_padded)
return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding
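A simplified sketch of the local (non-DP) half of get_padding: pad up to the next captured CUDA graph size when the batch fits, otherwise round up to a multiple of the tensor-parallel size for eager mode. The real method defers CUDA graph padding to vllm_config.pad_for_cudagraph and then adds DP padding via get_dp_padding; the list of captured sizes and the guard below are stand-ins.

def local_padding(num_tokens_unpadded: int, cudagraph_batch_sizes: list[int],
                  tp_size: int) -> int:
    if (cudagraph_batch_sizes
            and num_tokens_unpadded <= cudagraph_batch_sizes[-1]):
        # Piecewise CUDA graphs: pad up to the next captured batch size.
        num_tokens_padded = next(size for size in cudagraph_batch_sizes
                                 if size >= num_tokens_unpadded)
    else:
        # Eager mode: round up to a multiple of tensor_parallel_size.
        num_tokens_padded = -(-num_tokens_unpadded // tp_size) * tp_size
    return num_tokens_padded - num_tokens_unpadded

assert local_padding(6, [8, 16, 32], tp_size=4) == 2    # padded to 8
assert local_padding(40, [8, 16, 32], tp_size=4) == 0   # already a multiple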
def get_dp_padding_ubatch(
self,
ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size
if dp_size == 1:
@@ -1379,54 +1384,63 @@ class GPUModelRunner(LoRAModelRunnerMixin):
first_ubatch_slice = ubatch_slices[0]
second_ubatch_slice = ubatch_slices[1]
first_ubatch_num_tokens = first_ubatch_slice[
1].stop - first_ubatch_slice[1].start
second_ubatch_num_tokens = second_ubatch_slice[
1].stop - second_ubatch_slice[1].start
# We don't support prefills yet so the two ubatches should only differ
# by at most one token
assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
from vllm.utils import round_up
num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
num_tokens_padded = round_up(num_tokens_unpadded, 2)
num_tokens_per_ubatch = num_tokens_padded // 2
# Note that we compute the number of padded tokens per ubatch
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
num_tokens_per_ubatch)
num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \
num_tokens_unpadded
return num_pad_tokens, num_tokens_after_padding
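The padding arithmetic in get_dp_padding_ubatch, restated as a small sketch. get_dp_padding (the cross-rank max) is replaced by an explicit max_tokens_per_ubatch_across_dp argument, which is a hypothetical stand-in; the final expression matches the one in the diff.

def ubatch_pad_tokens(first_ubatch_tokens: int, second_ubatch_tokens: int,
                      max_tokens_per_ubatch_across_dp: int) -> int:
    num_tokens_unpadded = first_ubatch_tokens + second_ubatch_tokens
    num_tokens_padded = (num_tokens_unpadded + 1) // 2 * 2  # round_up(x, 2)
    num_tokens_per_ubatch = num_tokens_padded // 2
    # Stand-in for get_dp_padding(): extra tokens this rank needs per ubatch
    # to match the largest ubatch across all DP ranks.
    dp_pad = max_tokens_per_ubatch_across_dp - num_tokens_per_ubatch
    # Same final expression as the diff: total padding for the whole batch.
    return ((dp_pad + num_tokens_per_ubatch) * 2) - num_tokens_unpadded

# This rank has 5 + 4 decode tokens; the largest ubatch across DP is 6 tokens,
# so the rank must pad its 9 tokens out to 12 (6 per micro-batch).
assert ubatch_pad_tokens(5, 4, max_tokens_per_ubatch_across_dp=6) == 3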
# This doesn't actually pad the ubatch slices. It just shifts the
# split point to the correct value so that padding can be applied
# to the second ubatch later. Should be called after ubatch
# slicing but before attention meta data creation
def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
num_pad_tokens: int):
original_num_tokens = ubatch_slices[1][1].stop
assert num_pad_tokens < original_num_tokens
total_num_tokens_per_ubatch = (original_num_tokens +
num_pad_tokens) // 2
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
original_num_tokens)
ubatch_slices[0] = (padded_first_ubatch_slice,
padded_first_ubatch_slice)
ubatch_slices[1] = (padded_second_ubatch_slice,
padded_second_ubatch_slice)
# This is where the second ubatch is adjusted to account for the padding.
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
# (num_tokens + padding)
def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
num_total_tokens: int):
# TODO Add asserts to make sure stage one ran
padded_second_ubatch_slice = slice(ubatch_slices[1][1].start,
num_total_tokens)
ubatch_slices[1] = (padded_second_ubatch_slice,
padded_second_ubatch_slice)
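A sketch of the two padding stages on concrete numbers, assuming the second-stage call receives num_tokens plus the DP padding as its total. pad_stages is a hypothetical helper that collapses both stages for illustration; in the diff they run on either side of attention metadata creation.

def pad_stages(num_tokens: int, num_pad_tokens: int):
    # Stage one: move the split point so each ubatch covers half of
    # (num_tokens + padding); the second slice still ends at num_tokens.
    per_ubatch = (num_tokens + num_pad_tokens) // 2
    first = slice(0, per_ubatch)
    second = slice(per_ubatch, num_tokens)
    # Stage two: stretch the second slice out to the padded total.
    second = slice(second.start, num_tokens + num_pad_tokens)
    return [(first, first), (second, second)]

# 9 scheduled tokens padded by 3: both ubatches end up covering 6 slots.
assert pad_stages(9, 3) == [(slice(0, 6), slice(0, 6)),
                            (slice(6, 12), slice(6, 12))]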
def should_ubatch(self, should_ubatch: bool) -> bool:
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size,
dp_rank)
def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
# Dummy batch. (hopefully we are the last one so we can just
@@ -1455,8 +1469,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device))
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
slice(0, num_tokens), None, False)
return input_ids, positions, inputs_embeds, intermediate_tensors
@@ -1506,58 +1519,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tokens_slice, intermediate_tensors, True)
return input_ids, positions, inputs_embeds, intermediate_tensors
def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
scheduler_output: Optional["SchedulerOutput"]) -> tuple:
if use_dummy_input:
return self._get_dummy_model_inputs(tokens_slice.stop -
tokens_slice.start)
else:
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
compute_stream, is_dummy_run,
num_tokens_across_dp, skip_cuda_graphs,
scheduler_output) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
num_tokens = (tokens_slice.stop - tokens_slice.start)
forward_contexts.append(
create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs))
ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices),
compute_stream=compute_stream,
forward_contexts=forward_contexts,
device=self.device)
ubatch_metadata: list[UbatchMetadata] = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
input_ids, positions, inputs_embeds, intermediate_tensors = \
self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
ubatch_metadata.append(
UbatchMetadata(context=ubatch_ctxs[i],
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors))
return ubatch_metadata
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
@torch.inference_mode()
def _ubatch_thread(results, model, ubatch_metadata):
with ubatch_metadata.context:
@@ -1578,11 +1586,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
model,
metadata,
))
ubatch_threads.append(thread)
thread.start()
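A stripped-down sketch of the threading structure used by _run_ubatches: one thread per micro-batch, results collected by ubatch index and concatenated. contextlib.nullcontext stands in for vLLM's UBatchContext, which in the real code coordinates the CUDA streams and events between the two micro-batches; the model and inputs here are toy stand-ins.

import contextlib
import threading
import torch

def run_ubatches(model, ubatch_inputs: list[torch.Tensor]) -> torch.Tensor:
    results = [None] * len(ubatch_inputs)

    def _ubatch_thread(idx: int, inputs: torch.Tensor) -> None:
        # Stand-in for "with ubatch_metadata.context:" in the real code.
        with contextlib.nullcontext():
            results[idx] = model(inputs)

    threads = [
        threading.Thread(target=_ubatch_thread, args=(i, inp))
        for i, inp in enumerate(ubatch_inputs)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    # Stitch the per-ubatch outputs back together in ubatch order.
    return torch.cat(results)

doubled = run_ubatches(lambda x: x * 2, [torch.ones(3), torch.ones(2)])
assert torch.equal(doubled, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0]))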
@@ -1602,37 +1610,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False):
# run micro-batched
if ubatch_slices is not None:
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
compute_stream = torch.cuda.current_stream()
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
compute_stream=compute_stream,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
scheduler_output=scheduler_output)
return self._run_ubatches(ubatch_metadata, self.model)
# run normal batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
self.model_inputs(slice(0, num_scheduled_tokens),
is_dummy_run,
scheduler_output)
with set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs):
return self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
def _pool(
self,
hidden_states: torch.Tensor,
@@ -1693,18 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output)
# num_scheduled_tokens_old = scheduler_output.total_num_scheduled_tokens
# num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_scheduled_tokens_old)
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices,
num_pad_tokens,
num_tokens_after_padding) = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = num_scheduled_tokens
if ubatch_slices and num_pad_tokens > 0:
num_input_tokens += num_pad_tokens
self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
elif ubatch_slices is None:
num_pad, num_tokens_after_padding = self.get_padding(
num_input_tokens)
num_input_tokens += num_pad
# Some attention backends only support CUDA Graphs in pure decode.
@@ -1856,6 +1867,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
spec_token_ids = None
else:
assert not ubatch_slices
assert isinstance(attn_metadata, dict)
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
@@ -2301,7 +2314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP groups that may be
# attempting to microbatch here.
if self.parallel_config.enable_microbatching:
@@ -2323,8 +2336,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
# We currently only microbatch if the number of tokens is
# over a certain threshold.
attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph:
attn_metadata = {}
@@ -2354,7 +2367,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
outputs = self._run_model(