From 44a2b3494e5794dbc9859922bfb1e46c383bb12b Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Wed, 25 Jun 2025 21:39:33 +0000
Subject: [PATCH] add attention splitting to dummy runs

Signed-off-by: Sage Moore
---
 vllm/v1/attention/backends/mla/common.py   |  8 +--
 vllm/v1/attention/backends/mla/flashmla.py |  2 +-
 vllm/v1/worker/gpu_model_runner.py         | 75 +++++++++++++++-------
 3 files changed, 58 insertions(+), 27 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index d82afa5b630fd..b2c3d035a62cd 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -630,10 +630,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         m.max_query_len = 1  # decode-only
 
         # Update state usually set in reorder_batch.
-        self._num_decodes = m.num_reqs
-        self._num_decode_tokens = m.num_actual_tokens
-        self._num_prefills = 0
-        self._num_prefill_tokens = 0
+        # self._num_decodes = m.num_reqs
+        # self._num_decode_tokens = m.num_actual_tokens
+        # self._num_prefills = 0
+        # self._num_prefill_tokens = 0
         return self.build(0, m)
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 2b42e71053c81..476e11bb4cebe 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -77,7 +77,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         n = num_splits.size(0)
         logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
 
-        if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1]:
+        if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
             # First time around (CUDAGraph capture), allocate the static buffer
             if self.cg_buf_tile_scheduler_metadata is None:
                 self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 76b7b6b2f1e81..be9136fef240e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -228,7 +228,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # The batch sizes in the config are in descending order.
         self.cudagraph_batch_sizes = list(
             reversed(self.compilation_config.cudagraph_capture_sizes))
-
+        logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
         self.full_cuda_graph = self.compilation_config.full_cuda_graph
         self.full_cuda_graph = True
         logger.info(f"full_cuda_graph {self.full_cuda_graph}")
@@ -558,7 +558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.input_batch.refresh_sampling_metadata()
 
     def _ubatch_split(
-        self, query_start_loc_np: torch.Tensor,
+        self,
         max_num_scheduled_tokens: int,
         scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -707,7 +707,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
 
         ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
-            self.query_start_loc_np, max_num_scheduled_tokens,
+            max_num_scheduled_tokens,
            scheduler_output)
         should_ubatch = self.should_ubatch(True if ubatch_slices else False)
         # Don't attempt to microbatch unless every other DP worker is also microbatching
@@ -1343,6 +1343,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         num_tokens_padded = num_tokens_unpadded
+        logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
         if (self.use_cuda_graph and
                 num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
@@ -2279,7 +2280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # for now.
         should_ubatch = num_tokens >= \
             self.parallel_config.microbatching_token_threshold and \
-            allow_microbatching
+            allow_microbatching and capture_attn_cudagraph
         # _dummy_run doesn't go through _prepare_inputs so
         # we synchronize with other DP ranks here
         should_ubatch = self.should_ubatch(allow_microbatching)
@@ -2304,9 +2305,25 @@
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
 
-        attn_metadata: Optional[dict[str, Any]] = None
+        ubatch_slices = None
+        # We currently only microbatch if the number of tokens is
+        # over a certain threshold.
+        if should_ubatch:
+            # We only support decode-only cudagraphs
+            assert num_reqs == num_tokens
+            assert num_tokens % 2 == 0
+            ubatch_slices = [(slice(0, num_reqs // 2),
+                              slice(0, num_tokens // 2)),
+                             (slice(num_reqs // 2, num_reqs),
+                              slice(num_tokens // 2, num_tokens))]
+
+
+        # attn_metadata: Optional[dict[str, Any]] = None
+        attn_metadata: Optional[PerLayerAttnMetadata] = None
         if capture_attn_cudagraph:
             attn_metadata = {}
+            if ubatch_slices is not None:
+                attn_metadata = [dict() for _ in range(len(ubatch_slices))]
 
         query_start_loc = self.query_start_loc[:num_reqs + 1]
         # Make sure max_model_len is used at the graph capture time.
@@ -2316,39 +2333,53 @@
                                      non_blocking=True)
         seq_lens = self.seq_lens[:num_reqs]
 
+        max_query_len = num_tokens
+        if ubatch_slices is not None:
+            max_query_len = 1
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,
-            max_query_len=num_tokens,
+            max_query_len=max_query_len,
         )
 
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
+            if ubatch_slices is not None:
+                for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
+                    # Run a dummy batch if it's an empty ubatch
+                    if token_slice.stop <= token_slice.start:
+                        attn_metadata_i = None
+                    else:
+                        attn_metadata_i = (
+                            self.attn_metadata_builders[kv_cache_group_id].
+                            build_slice(
+                                req_slice=req_slice,
+                                token_slice=token_slice,
+                                max_query_len=max_query_len,
+                                common_prefix_len=0,
+                                common_attn_metadata=common_attn_metadata,
+                            ))
+                    for layer_name in kv_cache_group_spec.layer_names:
+                        assert type(attn_metadata) is list
+                        # assert attn_metadata_i is not None
+                        # What if it's None? Do we still add it to the list?
+                        attn_metadata[ubid][layer_name] = attn_metadata_i
+            else:
+                attn_metadata_i = self.attn_metadata_builders[
+                    kv_cache_group_id].build_for_cudagraph_capture(
+                        common_attn_metadata)
+                for layer_name in kv_cache_group_spec.layer_names:
+                    attn_metadata[layer_name] = attn_metadata_i
 
-            attn_metadata_i = self.attn_metadata_builders[
-                kv_cache_group_id].build_for_cudagraph_capture(
-                    common_attn_metadata)
-            for layer_name in kv_cache_group_spec.layer_names:
-                attn_metadata[layer_name] = attn_metadata_i
-
-        dummy_microbatches = None
-        # We currently only microbatch if the number of tokens is
-        # over a certain threshold.
-        if should_ubatch:
-            assert num_tokens % 2 == 0
-            dummy_microbatches = [(slice(0, num_tokens // 2),
-                                   slice(0, num_tokens // 2)),
-                                  (slice(num_tokens // 2, num_tokens),
-                                   slice(num_tokens // 2, num_tokens))]
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             outputs = self._run_model(
                 attn_metadata,
                 num_tokens,
-                ubatch_slices=dummy_microbatches,
+                ubatch_slices=ubatch_slices,
                 is_dummy_run=True,
                 num_tokens_across_dp=num_tokens_across_dp,
                 build_cuda_graph=build_cuda_graph
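
Note for reviewers (illustration only, not part of the patch): the dummy-run microbatching added above assumes a decode-only batch, so the request split and the token split coincide and the batch is cut into two equal halves. A minimal standalone sketch of that slicing logic, using a hypothetical helper name:

def make_dummy_ubatch_slices(num_reqs: int, num_tokens: int):
    """Split a decode-only dummy batch into two equal microbatches."""
    # Decode-only: every request contributes exactly one token, so the
    # request slices and the token slices line up one-to-one.
    assert num_reqs == num_tokens
    assert num_tokens % 2 == 0
    half = num_tokens // 2
    return [(slice(0, half), slice(0, half)),
            (slice(half, num_reqs), slice(half, num_tokens))]

# Example: an 8-request decode-only dummy batch becomes two 4-request halves,
# mirroring the ubatch_slices built in _dummy_run above.
print(make_dummy_ubatch_slices(8, 8))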