Clean up some of the should_ubatch logic

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-03 14:22:53 +00:00
parent 83caef8bac
commit 7cc5a549ad

View File

@ -571,40 +571,52 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _ubatch_split(
self,
max_num_scheduled_tokens: int,
scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
scheduler_output: "SchedulerOutput"
) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_reqs = self.input_batch.num_reqs
if self.parallel_config.enable_microbatching and \
should_attempt_ubatching = \
self.parallel_config.enable_microbatching and \
total_num_scheduled_tokens >= \
self.parallel_config.microbatching_token_threshold \
and max_num_scheduled_tokens == 1:
# For pure decode we can just create ubatchs by cutting the request
# in half
b0_reqs_end = num_reqs // 2
b0_tokens_end = total_num_scheduled_tokens // 2
assert b0_reqs_end < num_reqs and \
b0_tokens_end < total_num_scheduled_tokens
return [
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
(slice(b0_reqs_end, num_reqs),
slice(b0_tokens_end, total_num_scheduled_tokens)),
]
and max_num_scheduled_tokens == 1
# Don't microbatch unless every other DP worker is also microbatching
should_ubatch = self.should_ubatch(should_attempt_ubatching)
if not should_ubatch:
return (None, 0, None)
# if self.parallel_config.enable_microbatching and \
# self.parallel_config.always_microbatch_if_enabled:
# print(f"PREFIL RUN total_num_scheduled_tokens: {total_num_scheduled_tokens} max_num_scheduled_tokens {max_num_scheduled_tokens}")
# TODO we can do something more advanced here to try to balance,
# i.e. split to the left of `total_num_scheduled_tokens // 2` if it
# is more balanced
# req_split_id = np.argmax(
# query_start_loc_np > (total_num_scheduled_tokens // 2))
# return [(slice(0, req_split_id),
# slice(0, query_start_loc_np[req_split_id])),
# (slice(req_split_id, num_reqs),
# slice(query_start_loc_np[req_split_id],
# total_num_scheduled_tokens))]
return None
# For pure decode we can just create ubatchs by cutting the request
# in half
b0_reqs_end = num_reqs // 2
b0_tokens_end = total_num_scheduled_tokens // 2
assert b0_reqs_end < num_reqs and \
b0_tokens_end < total_num_scheduled_tokens
ubatch_slices = [
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
(slice(b0_reqs_end, num_reqs),
slice(b0_tokens_end, total_num_scheduled_tokens)),
]
num_pad_tokens = 0
num_tokens_after_padding = None
ubatch_bailout = False
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
if num_pad_tokens > 0:
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
else:
# We bail out of ubatching here. This accounts for the case where
# the padding would result in an "empty" second ubatch.
# TODO: just make the second ubatch a dummy ubatch
# logger.info("FALLING BACK AND DISABLING UBATCHING")
ubatch_bailout = True
# Note that if we are attempting to ubatch by this point then we know that no
# DP ranks are doing dummy runs
should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
if not should_ubatch:
return (None, 0, None)
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
def _get_cumsum_and_arange(
self,
@ -719,55 +731,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
max_num_scheduled_tokens,
scheduler_output)
should_ubatch = self.should_ubatch(True if ubatch_slices else False)
# Don't attempt to microbatch unless every other DP worker is also microbatching
if not should_ubatch:
ubatch_slices = None
num_pad_tokens = 0
num_tokens_after_padding = None
ubatch_bailout = False
if ubatch_slices:
# logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
assert should_ubatch
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
# logger.info("UBATCH PADDING DONE")
if num_pad_tokens > 0:
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
else:
# We bail out of ubatching here. This accounts for the case where
# the padding would result in an "empty" second ubatch.
# TODO: just make the second ubatch a dummy ubatch
# logger.info("FALLING BACK AND DISABLING UBATCHING")
ubatch_bailout = True
# Note that if we are attempting to ubatch by this point then we know that no
# DP ranks are doing dummy runs
if ubatch_slices:
should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
if not should_ubatch:
# logger.info("SUCCESSFULLY BAILED OUT")
num_pad_tokens = 0
num_tokens_after_padding = None
ubatch_slices = None
# This AR is only necessary in the case described above where
# the second ubatch ends up being empty. NOte if you delete this go delete
# the second should_ubatch call in _dummy_run
# should_ubatch = self.should_ubatch(True if ubatch_slices else False)
# if not should_ubatch:
# num_pad_tokens = 0
# num_tokens_after_padding = None
# ubatch_slices = None
ubatch_slices, num_pad_tokens, num_tokens_after_padding = \
self._ubatch_split(max_num_scheduled_tokens,
scheduler_output)
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +