diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e7cd418f115d..82e8f2d8472e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2989,13 +2989,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
         if self.parallel_config.enable_dbo and allow_microbatching:
-            ubatch_slices, num_tokens_after_padding = ubatch_split(
+            ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
                 num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 uniform_decode=uniform_decode,
                 vllm_config=self.vllm_config,
             )
+            # Currently when DBO is enabled `ubatch_split` returns
+            # the num_tokens_after_padding for a single ubatch, but we have 2
+            # TODO(sage,lucas): this is cruft that should be addressed in the
+            # padding refactor.
+            if ubatch_num_tokens_after_padding is not None:
+                num_tokens_after_padding = ubatch_num_tokens_after_padding * 2

             # If we failed to microbatch, currently need to resynchronize
             # TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3112,8 +3118,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # filter out the valid batch descriptor
         _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-            BatchDescriptor(num_tokens=num_tokens,
-                            uniform_decode=uniform_decode))
+            BatchDescriptor(num_tokens=num_tokens_after_padding,
+                            uniform_decode=uniform_decode)) \
+            if not is_profile else (CUDAGraphMode.NONE, None)
         if cudagraph_runtime_mode is not None:
             # we allow forcing NONE when the dispatcher disagrees to support
             # warm ups for cudagraph capture
@@ -3125,7 +3132,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
-            num_tokens = num_tokens // 2
+            # Adjust values to reflect a single ubatch.
+            # TODO(sage,lucas): this is cruft that should be addressed in
+            # the padding refactor.
+            num_tokens_after_padding = ubatch_slices[0].num_tokens
+            if num_tokens_across_dp is not None:
+                num_tokens_across_dp[:] = num_tokens_after_padding
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
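
A minimal, standalone sketch (not vLLM code) of the bookkeeping this patch changes: when DBO splits the batch into two ubatches, `ubatch_split` reports the padded token count for a single ubatch, so the full-batch value fed to cudagraph dispatch is twice that, while per-ubatch execution later uses the slice's own size. The names `plan_num_tokens` and the toy `UbatchSlice` below are illustrative placeholders, not vLLM APIs.

from dataclasses import dataclass


@dataclass
class UbatchSlice:
    num_tokens: int  # padded token count for one ubatch


def plan_num_tokens(ubatch_slices, ubatch_num_tokens_after_padding,
                    fallback_num_tokens):
    # `ubatch_split` reports padding for a single ubatch; with DBO there
    # are two ubatches, so the full-batch dispatch size is twice that.
    if ubatch_slices is None or ubatch_num_tokens_after_padding is None:
        return fallback_num_tokens
    return ubatch_num_tokens_after_padding * 2


# Each ubatch padded to 64 tokens -> dispatch sees 128 tokens total,
# while a single ubatch later runs with ubatch_slices[0].num_tokens == 64.
slices = [UbatchSlice(num_tokens=64), UbatchSlice(num_tokens=64)]
assert plan_num_tokens(slices, 64, fallback_num_tokens=100) == 128
assert slices[0].num_tokens == 64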