diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 76d4801e0c1bb..3c1b544f0efa5 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -48,21 +48,22 @@ class DPMetadata:
         return num_tokens_tensor
 
     @staticmethod
-    def should_ubatch_across_dp(should_ubatch: bool, dp_size: int, dp_rank: int) -> bool:
+    def should_ubatch_across_dp(should_ubatch: bool, dp_size: int,
+                                dp_rank: int) -> bool:
         should_ubatch_across_dp = [0] * dp_size
         should_ubatch_across_dp[dp_rank] = 1 if should_ubatch else 0
         should_ubatch_tensor = torch.tensor(should_ubatch_across_dp,
-                                        device="cpu",
-                                        dtype=torch.int32)
+                                            device="cpu",
+                                            dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
-        # This function uses the same ProcessGroup for all reduce as 
-        # num_tokens_across_dp. If there's an incorrect ordering of ARs 
-        # across DP ranks, this tensor can end up containing the number 
+        # This function uses the same ProcessGroup for all reduce as
+        # num_tokens_across_dp. If there's an incorrect ordering of ARs
+        # across DP ranks, this tensor can end up containing the number
         # of padded tokens for a DP rank.
-
-        assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
+        assert torch.all((should_ubatch_tensor == 0)
+                         | (should_ubatch_tensor == 1))
         result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
         return result
 
@@ -183,7 +184,7 @@ def set_forward_context(
         forward_start_time = time.perf_counter()
 
     forward_context = create_forward_context(attn_metadata, vllm_config,
-                                             virtual_engine, num_tokens, 
+                                             virtual_engine, num_tokens,
                                              num_tokens_across_dp,
                                              skip_cuda_graphs)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 5480eae0ac7ef..37347d727dd74 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.v1.worker.ubatching import get_current_ubatch_context
 from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
+from vllm.v1.worker.ubatching import get_current_ubatch_context
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -1567,19 +1567,26 @@ class FusedMoE(torch.nn.Module):
             chunk_size = chunk_end - chunk_start
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
-            
+
             ubatch_ctx = get_current_ubatch_context()
             ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
             batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
-            batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
-            batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
+
+            assert self.batched_hidden_states is not None
+            assert self.batched_router_logits is not None
+            batched_hidden_states = self.batched_hidden_states[
+                batch_buffer_idx, :]
+            batched_router_logits = self.batched_router_logits[
+                batch_buffer_idx, :]
 
             assert (batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
             assert (batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
-            staged_hidden_states = batched_hidden_states[:chunk_size, :]  # type: ignore
-            staged_router_logits = batched_router_logits[:chunk_size, :]  # type: ignore
+            staged_hidden_states = batched_hidden_states[:
+                                                         chunk_size, :]  # type: ignore
+            staged_router_logits = batched_router_logits[:
+                                                         chunk_size, :]  # type: ignore
             staged_hidden_states.copy_(hidden_states, non_blocking=True)
             staged_router_logits.copy_(router_logits, non_blocking=True)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 36c79b2d21a77..8b804408ea41f 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -27,7 +27,7 @@ from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    make_local_attention_virtual_batches)
+    make_local_attention_virtual_batches, slice_query_start_locs)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index e738d11665683..542e8b5a65377 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -60,6 +60,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError
 
+    def build_slice(
+        self,
+        req_slice: slice,
+        token_slice: slice,
+        max_query_len: int,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+    ) -> M:
+        """
+        Should only be called on builders that support attention slicing
+        for micro batching
+        """
+        raise NotImplementedError
+
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         """
@@ -105,6 +119,7 @@ def slice_query_start_locs(
     return query_start_loc[req_slice.start: req_slice.stop + 1] -\
         query_start_loc[req_slice.start]
 
+
 def validate_kv_sharing_target(current_layer_name, target_layer_name,
                                static_forward_context):
     error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 79f4300d4f566..94a623dc878b8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
+import dataclasses
 import gc
 import threading
 import time
@@ -29,8 +30,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import (create_forward_context, get_forward_context,
-                                  override_forward_context, DPMetadata,
+from vllm.forward_context import (DPMetadata, create_forward_context,
+                                  get_forward_context,
+                                  override_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
@@ -48,8 +50,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
-from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -75,7 +77,6 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
 from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
-import dataclasses
 
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -99,6 +100,7 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
 UbatchSlice: TypeAlias = tuple[slice, slice]
 UBatchSlices: TypeAlias = list[UbatchSlice]
 
+
 @dataclasses.dataclass
 class UbatchMetadata:
     context: UBatchContext
@@ -577,10 +579,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.input_batch.refresh_sampling_metadata()
 
     def _ubatch_split(
-        self,
-        max_num_scheduled_tokens: int,
-        scheduler_output: "SchedulerOutput"
-    ) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
+        self, max_num_scheduled_tokens: int,
+        scheduler_output: "SchedulerOutput"
+    ) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
         # Don't bother with the should_ubatch handshaking unless microbatching
         # is enabled
         if not self.parallel_config.enable_microbatching:
@@ -607,27 +608,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             b0_tokens_end < total_num_scheduled_tokens
 
         ubatch_slices = [
             (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-            (slice(b0_reqs_end, num_reqs),
-             slice(b0_tokens_end, total_num_scheduled_tokens)),
+            (slice(b0_reqs_end,
+                   num_reqs), slice(b0_tokens_end,
+                                    total_num_scheduled_tokens)),
         ]
 
         # Compute ubatch padding. This currently only accounts for DP padding
         num_pad_tokens = 0
         num_tokens_after_padding = None
         ubatch_abort = False
-        num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
+        num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(
+            ubatch_slices)
         if num_pad_tokens > 0:
-            # Check if the padding would result in an empty second ubatch. 
+            # Check if the padding would result in an empty second ubatch.
             # If so abort ubatching
             if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
                 self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
             else:
                 ubatch_abort = True
-        
-        # Note that if we are attempting to ubatch by this point then we know that no
-        # DP ranks are doing dummy runs. Meaning, we don't need a second call to
-        # should_ubatch in _dummy_run
-        should_ubatch = self.should_ubatch(False if ubatch_abort else True)
+
+        # Note that if we are attempting to ubatch by this point then we know
+        # that no DP ranks are doing dummy runs. Meaning, we don't need a
+        # second call to should_ubatch in _dummy_run
+        should_ubatch = self.should_ubatch(not ubatch_abort)
         if not should_ubatch:
             return (None, 0, None)
         return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
@@ -653,12 +656,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return cu_num_tokens, arange
 
     def _prepare_inputs(
-        self,
-        scheduler_output: "SchedulerOutput"
-    ) -> tuple[dict[str, Any], bool, torch.Tensor,
+        self, scheduler_output: "SchedulerOutput"
+    ) -> tuple[PerLayerAttnMetadata, bool, torch.Tensor,
                Optional[SpecDecodeMetadata], np.ndarray,
-               Optional[UBatchSlices],
-               int, Optional[torch.Tensor]]:
+               Optional[UBatchSlices], int, Optional[torch.Tensor]]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
@@ -873,9 +874,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return (attn_metadata, attention_cuda_graphs, logits_indices,
-                spec_decode_metadata, num_scheduled_tokens, ubatch_slices, num_pad_tokens,
-                num_tokens_after_padding)
+        return (attn_metadata, attention_cuda_graphs, logits_indices,
+                spec_decode_metadata, num_scheduled_tokens, ubatch_slices,
+                num_pad_tokens, num_tokens_after_padding)
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1343,8 +1344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                                 dtype=torch.int32)
         return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
 
-    def get_padding(self,
-                    num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]:
+    def get_padding(
+            self,
+            num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]:
 
         num_tokens_padded = num_tokens_unpadded
 
@@ -1352,7 +1354,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
-            num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
+            num_tokens_padded = self.vllm_config.pad_for_cudagraph(
+                num_tokens_unpadded)
         else:
             # Eager mode.
             # Pad tokens to multiple of tensor_parallel_size when
@@ -1364,12 +1367,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
 
         num_pad_tokens = num_tokens_padded - num_tokens_unpadded
-        num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_padded)
+        num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
+            num_tokens_padded)
 
         return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding
 
-    def get_dp_padding_ubatch(self,
-                              ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
+    def get_dp_padding_ubatch(
+            self,
+            ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
 
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         if dp_size == 1:
@@ -1379,54 +1384,63 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         first_ubatch_slice = ubatch_slices[0]
         second_ubatch_slice = ubatch_slices[1]
 
-        first_ubatch_num_tokens = first_ubatch_slice[1].stop - first_ubatch_slice[1].start
-        second_ubatch_num_tokens = second_ubatch_slice[1].stop - second_ubatch_slice[1].start
-        # We don't support prefills yet so the two ubatches should only differ 
+        first_ubatch_num_tokens = first_ubatch_slice[
+            1].stop - first_ubatch_slice[1].start
+        second_ubatch_num_tokens = second_ubatch_slice[
+            1].stop - second_ubatch_slice[1].start
+        # We don't support prefills yet so the two ubatches should only differ
         # by at most one token
         assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
 
-        from vllm.utils import round_up
-
-        num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens 
+        num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
         num_tokens_padded = round_up(num_tokens_unpadded, 2)
         num_tokens_per_ubatch = num_tokens_padded // 2
 
         # Note that we compute the number of padded tokens per ubatch
-        num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_per_ubatch)
-        
+        num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
+            num_tokens_per_ubatch)
+
         num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \
             num_tokens_unpadded
 
         return num_pad_tokens, num_tokens_after_padding
 
-    # This doesn't actually pad the ubatch slices. It just shifts the 
-    # split point to the correct value so that padding can be applied 
+    # This doesn't actually pad the ubatch slices. It just shifts the
+    # split point to the correct value so that padding can be applied
     # to the second ubatch later. Should be called after ubatch
     # slicing but before attention meta data creation
-    def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices, 
+    def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
                                    num_pad_tokens: int):
         original_num_tokens = ubatch_slices[1][1].stop
         assert num_pad_tokens < original_num_tokens
-        total_num_tokens_per_ubatch = (original_num_tokens + num_pad_tokens) // 2
+        total_num_tokens_per_ubatch = (original_num_tokens +
+                                       num_pad_tokens) // 2
         padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
-        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, original_num_tokens)
+        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
+                                           original_num_tokens)
 
-        ubatch_slices[0] = (padded_first_ubatch_slice, padded_first_ubatch_slice)
-        ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
+        ubatch_slices[0] = (padded_first_ubatch_slice,
+                            padded_first_ubatch_slice)
+        ubatch_slices[1] = (padded_second_ubatch_slice,
+                            padded_second_ubatch_slice)
 
     # This is where the second ubatch is adjusted to account for the padding.
     # Should be called after attention metadata creation. This just pads
-    # the second ubatch slice out to the total number of tokens 
+    # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
-    def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
+    def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
+                                    num_total_tokens: int):
         # TODO Add asserts to make sure stage one ran
-        padded_second_ubatch_slice = slice(ubatch_slices[1][1].start, num_total_tokens)
-        ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
+        padded_second_ubatch_slice = slice(ubatch_slices[1][1].start,
+                                           num_total_tokens)
+        ubatch_slices[1] = (padded_second_ubatch_slice,
+                            padded_second_ubatch_slice)
 
     def should_ubatch(self, should_ubatch: bool) -> bool:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size, dp_rank)
+        return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size,
+                                                  dp_rank)
 
     def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
         # Dummy batch. (hopefully we are the last one so we can just
@@ -1455,8 +1469,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                              device=self.device))
 
         intermediate_tensors = self.sync_and_slice_intermediate_tensors(
-            slice(0, num_tokens), None, False)
-
+            slice(0, num_tokens), None, False)
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
@@ -1506,58 +1519,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 tokens_slice, intermediate_tensors, True)
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
-    def model_inputs(self, tokens_slice: slice, use_dummy_input: bool, 
+    def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
                      scheduler_output: Optional["SchedulerOutput"]) -> tuple:
         if use_dummy_input:
-            return self._get_dummy_model_inputs(tokens_slice.stop - tokens_slice.start)
+            return self._get_dummy_model_inputs(tokens_slice.stop -
+                                                tokens_slice.start)
         else:
             assert scheduler_output is not None
             return self._get_model_inputs(tokens_slice, scheduler_output)
 
-
-
-    def _make_ubatch_metadata(self, 
-                              ubatch_slices, 
-                              attn_metadata, 
-                              compute_stream, 
-                              is_dummy_run,
-                              num_tokens_across_dp, 
-                              skip_cuda_graphs,
+    def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
+                              compute_stream, is_dummy_run,
+                              num_tokens_across_dp, skip_cuda_graphs,
                               scheduler_output) -> list[UbatchMetadata]:
 
         # Create one forward context per ubatch
         forward_contexts = []
         for i, (_, tokens_slice) in enumerate(ubatch_slices):
             num_tokens = (tokens_slice.stop - tokens_slice.start)
-            forward_contexts.append(create_forward_context(
-                attn_metadata[i]
-                if attn_metadata is not None else None,
-                self.vllm_config,
-                num_tokens=num_tokens,
-                num_tokens_across_dp=num_tokens_across_dp,
-                skip_cuda_graphs=skip_cuda_graphs))
+            forward_contexts.append(
+                create_forward_context(
+                    attn_metadata[i] if attn_metadata is not None else None,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp,
+                    skip_cuda_graphs=skip_cuda_graphs))
 
-        ubatch_ctxs = make_ubatch_contexts(num_micro_batches=len(ubatch_slices),
-                                           compute_stream=compute_stream,
-                                           forward_contexts=forward_contexts,
-                                           device=self.device)
+        ubatch_ctxs = make_ubatch_contexts(
+            num_micro_batches=len(ubatch_slices),
+            compute_stream=compute_stream,
+            forward_contexts=forward_contexts,
+            device=self.device)
 
         ubatch_metadata: list[UbatchMetadata] = []
         for i, (_, tokens_slice) in enumerate(ubatch_slices):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
-            ubatch_metadata.append(UbatchMetadata(
-                context=ubatch_ctxs[i],
-                input_ids=input_ids,
-                positions=positions,
-                inputs_embeds=inputs_embeds,
-                intermediate_tensors=intermediate_tensors
-            ))
-
+            ubatch_metadata.append(
+                UbatchMetadata(context=ubatch_ctxs[i],
+                               input_ids=input_ids,
+                               positions=positions,
+                               inputs_embeds=inputs_embeds,
+                               intermediate_tensors=intermediate_tensors))
+
         return ubatch_metadata
-
     def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
+
         @torch.inference_mode()
         def _ubatch_thread(results, model, ubatch_metadata):
             with ubatch_metadata.context:
@@ -1578,11 +1586,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         ubatch_threads = []
         for metadata in ubatch_metadata:
             thread = threading.Thread(target=_ubatch_thread,
-                                        args=(
-                                            results,
-                                            model,
-                                            metadata,
-                                        ))
+                                      args=(
+                                          results,
+                                          model,
+                                          metadata,
+                                      ))
             ubatch_threads.append(thread)
             thread.start()
@@ -1602,37 +1610,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    num_tokens_across_dp: Optional[torch.Tensor] = None,
                    skip_cuda_graphs: bool = False):
-
         # run micro-batched
         if ubatch_slices is not None:
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             compute_stream = torch.cuda.current_stream()
             ubatch_metadata = self._make_ubatch_metadata(
-                                ubatch_slices=ubatch_slices,
-                                attn_metadata=attn_metadata,
-                                compute_stream=compute_stream,
-                                is_dummy_run=is_dummy_run,
-                                num_tokens_across_dp=num_tokens_across_dp,
-                                skip_cuda_graphs=skip_cuda_graphs,
-                                scheduler_output=scheduler_output
-                            )
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                compute_stream=compute_stream,
+                is_dummy_run=is_dummy_run,
+                num_tokens_across_dp=num_tokens_across_dp,
+                skip_cuda_graphs=skip_cuda_graphs,
+                scheduler_output=scheduler_output)
 
             return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self.model_inputs(slice(0, num_scheduled_tokens), is_dummy_run, scheduler_output)
+                self.model_inputs(slice(0, num_scheduled_tokens),
+                                  is_dummy_run,
+                                  scheduler_output)
             with set_forward_context(attn_metadata,
-                                    vllm_config=self.vllm_config,
-                                    num_tokens=num_scheduled_tokens or 1,
-                                    num_tokens_across_dp=num_tokens_across_dp,
-                                    skip_cuda_graphs=skip_cuda_graphs):
+                                     vllm_config=self.vllm_config,
+                                     num_tokens=num_scheduled_tokens or 1,
+                                     num_tokens_across_dp=num_tokens_across_dp,
+                                     skip_cuda_graphs=skip_cuda_graphs):
                 return self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                 )
+
     def _pool(
         self,
         hidden_states: torch.Tensor,
@@ -1693,18 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 return self.kv_connector_no_forward(scheduler_output)
 
-        # num_scheduled_tokens_old = scheduler_output.total_num_scheduled_tokens
-        # num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_scheduled_tokens_old)
         # Prepare the decoder inputs.
-        attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices, num_pad_tokens, num_tokens_after_padding = (
-            self._prepare_inputs(scheduler_output))
+        (attn_metadata, attention_cuda_graphs, logits_indices,
+         spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices,
+         num_pad_tokens,
+         num_tokens_after_padding) = self._prepare_inputs(scheduler_output)
+
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         num_input_tokens = num_scheduled_tokens
         if ubatch_slices and num_pad_tokens > 0:
             num_input_tokens += num_pad_tokens
             self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
         elif ubatch_slices is None:
-            num_pad, num_tokens_after_padding = self.get_padding(num_input_tokens)
+            num_pad, num_tokens_after_padding = self.get_padding(
+                num_input_tokens)
             num_input_tokens += num_pad
 
         # Some attention backends only support CUDA Graphs in pure decode.
@@ -1856,6 +1867,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Speculative decoding is not enabled.
             spec_token_ids = None
         else:
+            assert not ubatch_slices
+            assert isinstance(attn_metadata, dict)
            spec_token_ids = self.propose_draft_token_ids(
                 scheduler_output,
                 valid_sampled_token_ids,
@@ -2301,7 +2314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         is_profile: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
 
-        # _dummy_run doesn't go through _prepare_inputs so 
+        # _dummy_run doesn't go through _prepare_inputs so
         # we synchronize with other DP groups that may be
         # attempting to microbatch here.
         if self.parallel_config.enable_microbatching:
@@ -2323,8 +2336,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                             dtype=np.int32)
 
-        # We currently only microbatch if the number of tokens is 
-        # over a certain threshold. 
+        # We currently only microbatch if the number of tokens is
+        # over a certain threshold.
         attn_metadata: Optional[dict[str, Any]] = None
         if capture_attn_cudagraph:
             attn_metadata = {}
@@ -2354,7 +2367,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
 
-
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             outputs = self._run_model(
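
For readers following the padding logic in the gpu_model_runner.py changes above, here is a minimal standalone sketch (not part of the patch) of the two-stage ubatch padding flow: the first stage shifts the split point so that both ubatches hold the same number of tokens once padding is appended, and the second stage extends the second ubatch over the padded tokens. It operates on plain token slices, whereas the real code carries (request_slice, token_slice) pairs, and the helper names here are illustrative only.

# Illustrative sketch, assuming the simple case with no DP handshake involved.
def split_into_two_ubatches(num_tokens: int) -> list[slice]:
    # Initial split: ubatch 0 gets the first half of the tokens, ubatch 1 the rest.
    mid = num_tokens // 2
    return [slice(0, mid), slice(mid, num_tokens)]


def pad_first_stage(slices: list[slice], num_pad_tokens: int) -> None:
    # Mirrors pad_out_ubatch_first_stage: move the split point so that, after
    # padding is appended, both ubatches have (total + pad) // 2 tokens.
    total = slices[1].stop
    per_ubatch = (total + num_pad_tokens) // 2
    slices[0] = slice(0, per_ubatch)
    slices[1] = slice(per_ubatch, total)


def pad_second_stage(slices: list[slice], num_total_tokens: int) -> None:
    # Mirrors pad_out_ubatch_second_stage: extend the second ubatch to cover
    # the padded token range as well.
    slices[1] = slice(slices[1].start, num_total_tokens)


if __name__ == "__main__":
    # 10 scheduled tokens with 4 tokens of padding -> 7 tokens per ubatch.
    slices = split_into_two_ubatches(10)
    pad_first_stage(slices, num_pad_tokens=4)
    assert slices == [slice(0, 7), slice(7, 10)]
    pad_second_stage(slices, num_total_tokens=14)
    assert slices == [slice(0, 7), slice(7, 14)]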