Clean up some of the should_ubatch logic
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent 83caef8bac
commit 7cc5a549ad
@@ -571,40 +571,52 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _ubatch_split(
         self,
         max_num_scheduled_tokens: int,
-        scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
+        scheduler_output: "SchedulerOutput"
+    ) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         num_reqs = self.input_batch.num_reqs

-        if self.parallel_config.enable_microbatching and \
+        should_attempt_ubatching = \
+            self.parallel_config.enable_microbatching and \
             total_num_scheduled_tokens >= \
             self.parallel_config.microbatching_token_threshold \
-                and max_num_scheduled_tokens == 1:
-            # For pure decode we can just create ubatches by cutting the request
-            # in half
-            b0_reqs_end = num_reqs // 2
-            b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and \
-                b0_tokens_end < total_num_scheduled_tokens
-            return [
-                (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-                (slice(b0_reqs_end, num_reqs),
-                 slice(b0_tokens_end, total_num_scheduled_tokens)),
-            ]
-
-        # if self.parallel_config.enable_microbatching and \
-        #         self.parallel_config.always_microbatch_if_enabled:
-        #     print(f"PREFILL RUN total_num_scheduled_tokens: {total_num_scheduled_tokens} max_num_scheduled_tokens {max_num_scheduled_tokens}")
-        # TODO we can do something more advanced here to try to balance,
-        # i.e. split to the left of `total_num_scheduled_tokens // 2` if it
-        # is more balanced
-        # req_split_id = np.argmax(
-        #     query_start_loc_np > (total_num_scheduled_tokens // 2))
-        # return [(slice(0, req_split_id),
-        #          slice(0, query_start_loc_np[req_split_id])),
-        #         (slice(req_split_id, num_reqs),
-        #          slice(query_start_loc_np[req_split_id],
-        #                total_num_scheduled_tokens))]
-        return None
+            and max_num_scheduled_tokens == 1
+        # Don't microbatch unless every other DP worker is also microbatching
+        should_ubatch = self.should_ubatch(should_attempt_ubatching)
+        if not should_ubatch:
+            return (None, 0, None)
+        # For pure decode we can just create ubatches by cutting the request
+        # batch in half
+        b0_reqs_end = num_reqs // 2
+        b0_tokens_end = total_num_scheduled_tokens // 2
+        assert b0_reqs_end < num_reqs and \
+            b0_tokens_end < total_num_scheduled_tokens
+        ubatch_slices = [
+            (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
+            (slice(b0_reqs_end, num_reqs),
+             slice(b0_tokens_end, total_num_scheduled_tokens)),
+        ]
+        num_pad_tokens = 0
+        num_tokens_after_padding = None
+        ubatch_bailout = False
+        num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
+        logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_tokens_after {num_tokens_after_padding}")
+        if num_pad_tokens > 0:
+            if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
+                self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
+            else:
+                # We bail out of ubatching here. This accounts for the case where
+                # the padding would result in an "empty" second ubatch.
+                # TODO: just make the second ubatch a dummy ubatch
+                # logger.info("FALLING BACK AND DISABLING UBATCHING")
+                ubatch_bailout = True
+
+        # Note that if we are still attempting to ubatch by this point, we know
+        # that no DP ranks are doing dummy runs
+        should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
+        if not should_ubatch:
+            return (None, 0, None)
+        return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)

     def _get_cumsum_and_arange(
         self,
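For context, the pure-decode split added above can be exercised on its own. The sketch below is illustrative only: the function name split_decode_batch, its num_reqs < 2 guard, and the UBatchSlices alias (a list of (request_slice, token_slice) pairs, matching how the diff uses it) are assumptions of this example, not code from the commit.

from typing import Optional

# Hypothetical alias matching how the diff consumes ubatch slices:
# each ubatch is a (request_slice, token_slice) pair.
UBatchSlices = list[tuple[slice, slice]]

def split_decode_batch(num_reqs: int,
                       total_num_scheduled_tokens: int) -> Optional[UBatchSlices]:
    # In pure decode every request contributes exactly one scheduled token,
    # so halving the request range also halves the token range.
    if num_reqs < 2:
        # Illustrative guard: one request cannot form two non-empty ubatches.
        return None
    b0_reqs_end = num_reqs // 2
    b0_tokens_end = total_num_scheduled_tokens // 2
    return [
        (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
        (slice(b0_reqs_end, num_reqs),
         slice(b0_tokens_end, total_num_scheduled_tokens)),
    ]

# Example: 5 decode requests, 1 token each -> ubatches of 2 and 3 requests.
print(split_decode_batch(5, 5))
# [(slice(0, 2), slice(0, 2)), (slice(2, 5), slice(2, 5))]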
@@ -719,55 +731,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

-        ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
-            max_num_scheduled_tokens,
-            scheduler_output)
-        should_ubatch = self.should_ubatch(True if ubatch_slices else False)
-        # Don't attempt to microbatch unless every other DP worker is also microbatching
-        if not should_ubatch:
-            ubatch_slices = None
-
-        num_pad_tokens = 0
-        num_tokens_after_padding = None
-        ubatch_bailout = False
-        if ubatch_slices:
-            # logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
-            assert should_ubatch
-            num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
-            logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_tokens_after {num_tokens_after_padding}")
-            # logger.info("UBATCH PADDING DONE")
-            if num_pad_tokens > 0:
-                if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
-                    self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
-                else:
-                    # We bail out of ubatching here. This accounts for the case where
-                    # the padding would result in an "empty" second ubatch.
-                    # TODO: just make the second ubatch a dummy ubatch
-                    # logger.info("FALLING BACK AND DISABLING UBATCHING")
-                    ubatch_bailout = True
-
-        # Note that if we are still attempting to ubatch by this point, we know
-        # that no DP ranks are doing dummy runs
-        if ubatch_slices:
-            should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
-            if not should_ubatch:
-                # logger.info("SUCCESSFULLY BAILED OUT")
-                num_pad_tokens = 0
-                num_tokens_after_padding = None
-                ubatch_slices = None
-
-
-        # This AR (all-reduce) is only necessary in the case described above
-        # where the second ubatch ends up being empty. Note: if you delete this,
-        # go delete the second should_ubatch call in _dummy_run
-        # should_ubatch = self.should_ubatch(True if ubatch_slices else False)
-        # if not should_ubatch:
-        #     num_pad_tokens = 0
-        #     num_tokens_after_padding = None
-        #     ubatch_slices = None
-
-
-
+        ubatch_slices, num_pad_tokens, num_tokens_after_padding = \
+            self._ubatch_split(max_num_scheduled_tokens,
+                               scheduler_output)

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
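After this commit, the caller simply receives the already-agreed-upon result from _ubatch_split, so the should_ubatch calls and padding bookkeeping disappear from the call site. The diff does not show how should_ubatch reaches its DP-wide decision; the sketch below is a hypothetical illustration of such an all-or-nothing vote using torch.distributed, not vLLM's actual implementation.

import torch
import torch.distributed as dist

def should_ubatch_all_ranks(want_ubatch: bool,
                            dp_group: dist.ProcessGroup | None = None) -> bool:
    # Each rank votes 0 or 1; a MIN all-reduce yields 1 only when every
    # rank voted to microbatch, making the decision all-or-nothing.
    # (A CPU tensor assumes a gloo group; move it to the GPU for NCCL.)
    vote = torch.tensor([1 if want_ubatch else 0], dtype=torch.int32)
    dist.all_reduce(vote, op=dist.ReduceOp.MIN, group=dp_group)
    return bool(vote.item())

This mirrors the invariant stated in the diff's comments: no rank microbatches unless every other DP rank does too, which keeps the ranks' collective operations aligned.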