diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 34e62531a9c4f..78b74a2078cbb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -571,40 +571,52 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def _ubatch_split(
             self, max_num_scheduled_tokens: int,
-            scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
+            scheduler_output: "SchedulerOutput"
+    ) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         num_reqs = self.input_batch.num_reqs
-
-        if self.parallel_config.enable_microbatching and \
+        should_attempt_ubatching = \
+            self.parallel_config.enable_microbatching and \
             total_num_scheduled_tokens >= \
             self.parallel_config.microbatching_token_threshold \
-            and max_num_scheduled_tokens == 1:
-            # For pure decode we can just create ubatchs by cutting the request
-            # in half
-            b0_reqs_end = num_reqs // 2
-            b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and \
-                b0_tokens_end < total_num_scheduled_tokens
-            return [
-                (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-                (slice(b0_reqs_end, num_reqs),
-                 slice(b0_tokens_end, total_num_scheduled_tokens)),
-            ]
+            and max_num_scheduled_tokens == 1
+        # Don't microbatch unless every other DP worker is also microbatching
+        should_ubatch = self.should_ubatch(should_attempt_ubatching)
+        if not should_ubatch:
+            return (None, 0, None)
 
-        # if self.parallel_config.enable_microbatching and \
-        #     self.parallel_config.always_microbatch_if_enabled:
-        # print(f"PREFIL RUN total_num_scheduled_tokens: {total_num_scheduled_tokens} max_num_scheduled_tokens {max_num_scheduled_tokens}")
-        # TODO we can do something more advanced here to try to balance,
-        # i.e. split to the left of `total_num_scheduled_tokens // 2` if it
-        # is more balanced
-        # req_split_id = np.argmax(
-        #     query_start_loc_np > (total_num_scheduled_tokens // 2))
-        # return [(slice(0, req_split_id),
-        #          slice(0, query_start_loc_np[req_split_id])),
-        #         (slice(req_split_id, num_reqs),
-        #          slice(query_start_loc_np[req_split_id],
-        #                total_num_scheduled_tokens))]
-        return None
+        # For pure decode we can just create ubatches by cutting the requests
+        # in half
+        b0_reqs_end = num_reqs // 2
+        b0_tokens_end = total_num_scheduled_tokens // 2
+        assert b0_reqs_end < num_reqs and \
+            b0_tokens_end < total_num_scheduled_tokens
+        ubatch_slices = [
+            (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
+            (slice(b0_reqs_end, num_reqs),
+             slice(b0_tokens_end, total_num_scheduled_tokens)),
+        ]
+        num_pad_tokens = 0
+        num_tokens_after_padding = None
+        ubatch_bailout = False
+        num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
+        logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_tokens_after {num_tokens_after_padding}")
+        if num_pad_tokens > 0:
+            if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
+                self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
+            else:
+                # We bail out of ubatching here. This accounts for the case where
+                # the padding would result in an "empty" second ubatch.
+                # TODO: just make the second ubatch a dummy ubatch
+                # logger.info("FALLING BACK AND DISABLING UBATCHING")
+                ubatch_bailout = True
+
+        # Note that if we are attempting to ubatch by this point then we know that no
+        # DP ranks are doing dummy runs
+        should_ubatch = self.should_ubatch(not ubatch_bailout)
+        if not should_ubatch:
+            return (None, 0, None)
+        return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
 
     def _get_cumsum_and_arange(
             self,
@@ -719,55 +731,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
-            max_num_scheduled_tokens,
-            scheduler_output)
-        should_ubatch = self.should_ubatch(True if ubatch_slices else False)
-        # Don't attempt to microbatch unless every other DP worker is also microbatching
-        if not should_ubatch:
-            ubatch_slices = None
-
-        num_pad_tokens = 0
-        num_tokens_after_padding = None
-        ubatch_bailout = False
-        if ubatch_slices:
-            # logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
-            assert should_ubatch
-            num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
-            logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
-            # logger.info("UBATCH PADDING DONE")
-            if num_pad_tokens > 0:
-                if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
-                    self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
-                else:
-                    # We bail out of ubatching here. This accounts for the case where
-                    # the padding would result in an "empty" second ubatch.
-                    # TODO: just make the second ubatch a dummy ubatch
-                    # logger.info("FALLING BACK AND DISABLING UBATCHING")
-                    ubatch_bailout = True
-
-        # Note that if we are attempting to ubatch by this point then we know that no
-        # DP ranks are doing dummy runs
-        if ubatch_slices:
-            should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
-            if not should_ubatch:
-                # logger.info("SUCCESSFULLY BAILED OUT")
-                num_pad_tokens = 0
-                num_tokens_after_padding = None
-                ubatch_slices = None
-
-
-        # This AR is only necessary in the case described above where
-        # the second ubatch ends up being empty. NOte if you delete this go delete
-        # the second should_ubatch call in _dummy_run
-        # should_ubatch = self.should_ubatch(True if ubatch_slices else False)
-        # if not should_ubatch:
-        #     num_pad_tokens = 0
-        #     num_tokens_after_padding = None
-        #     ubatch_slices = None
-
-
-
+        ubatch_slices, num_pad_tokens, num_tokens_after_padding = \
+            self._ubatch_split(max_num_scheduled_tokens,
+                               scheduler_output)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +