ubatch padding should work now
Signed-off-by: Sage Moore <sage@neuralmagic.com>
commit a8675b7d98 (parent 8a75b3a1e5)
@@ -1231,9 +1231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                              dtype=torch.int32)
         return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
 
-    def pad_ubatch(self, target_num_tokens, ubatch_slice: UbatchSlice):
-        pass
-
     def get_dp_padding_ubatch(self,
             ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
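Note for readers of the diff: the no-op pad_ubatch stub is dropped because the code below now rewrites the ubatch token slices directly. A minimal sketch of the shape the code indexes (it reads ubatch_slices[i][1].stop / .start), assuming each entry behaves like a (request_slice, token_slice) pair; the type names below are hypothetical stand-ins, not vLLM's actual UbatchSlice / UBatchSlices definitions:

from typing import List, Tuple

# Hypothetical stand-ins: only the (request_slice, token_slice) indexing
# used in the diff is modeled here.
UbatchSliceSketch = Tuple[slice, slice]
UBatchSlicesSketch = List[UbatchSliceSketch]

def ubatch_num_tokens(ubatch: UbatchSliceSketch) -> int:
    # Token count of one ubatch, mirroring `slice.stop - slice.start` above.
    token_slice = ubatch[1]
    return token_slice.stop - token_slice.start

# Example: 96 scheduled tokens split into two ubatches of 48.
ubatch_slices: UBatchSlicesSketch = [(slice(0, 8), slice(0, 48)),
                                     (slice(8, 16), slice(48, 96))]
assert ubatch_num_tokens(ubatch_slices[0]) == 48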
@@ -1245,7 +1242,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         first_ubatch_num_tokens = first_ubatch_slice[1].stop - first_ubatch_slice[1].start
         second_ubatch_num_tokens = second_ubatch_slice[1].stop - second_ubatch_slice[1].start
 
-        num_tokens = max(first_ubatch_num_tokens, second_ubatch_num_tokens)
+        max_tokens_per_ubatch = max(first_ubatch_num_tokens, second_ubatch_num_tokens)
 
         # For DP: Don't pad when setting enforce_eager.
         # This lets us set enforce_eager on the prefiller in a P/D setup and
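The retained comment explains why DP padding is skipped under enforce_eager: a prefill-only instance in a P/D deployment has no captured CUDA-graph shape to satisfy, so it should not be forced to pad. A hedged sketch of that decision as a standalone function; the function name, parameters, and where the enforce_eager flag lives are assumptions for illustration, not vLLM's API:

def needs_dp_ubatch_padding(first_ubatch_num_tokens: int,
                            second_ubatch_num_tokens: int,
                            enforce_eager: bool) -> tuple[bool, int]:
    # Returns (should_pad, max_tokens_per_ubatch).
    max_tokens_per_ubatch = max(first_ubatch_num_tokens,
                                second_ubatch_num_tokens)
    if enforce_eager:
        # Eager mode: skip DP padding so the prefiller in a P/D setup
        # can run with its natural token count.
        return False, max_tokens_per_ubatch
    return True, max_tokens_per_ubatch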
@@ -1259,19 +1256,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return 0, None
 
         num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
-            num_tokens, dp_size, dp_rank)
+            max_tokens_per_ubatch, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
         num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                 dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
-        # Note that this num_pad_tokens will actually
-        # be the number of tokens added to each ubatch.
-        # Meaning 2*num_pad_tokens are added to each DP rank
-        num_pad_tokens = max_tokens_across_dp_cpu - num_tokens
-        self.pad_ubatch(num_pad_tokens, first_ubatch_slice)
-        self.pad_ubatch(num_pad_tokens, second_ubatch_slice)
-        return num_pad_tokens, num_tokens_after_padding
+
+        num_pad_tokens_first_ubatch = max_tokens_across_dp_cpu - first_ubatch_num_tokens
+        num_pad_tokens_second_ubatch = max_tokens_across_dp_cpu - second_ubatch_num_tokens
+
+        padded_first_ubatch_slice = slice(0, max_tokens_across_dp_cpu)
+        padded_second_ubatch_slice = slice(max_tokens_across_dp_cpu, 2 * max_tokens_across_dp_cpu)
+
+        ubatch_slices[0] = (ubatch_slices[0][0], padded_first_ubatch_slice)
+        ubatch_slices[1] = (ubatch_slices[1][0], padded_second_ubatch_slice)
+
+        return num_pad_tokens_first_ubatch + num_pad_tokens_second_ubatch, num_tokens_after_padding
 
     def should_ubatch(self, should_ubatch: bool) -> bool:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
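The rewritten body replaces the old pad_ubatch calls: each rank contributes max_tokens_per_ubatch to the DP gather, takes the maximum, and then rewrites both token slices so each ubatch spans exactly max_tokens_across_dp_cpu tokens, with the second ubatch starting where the padded first one ends. A standalone sketch of that step, assuming the same (request_slice, token_slice) tuples as above and taking the already-reduced DP maximum as an argument instead of performing the gather:

import torch

def pad_ubatch_slices(ubatch_slices, max_tokens_across_dp_cpu: int, dp_size: int):
    # Token counts of the two ubatches before padding.
    first_num_tokens = ubatch_slices[0][1].stop - ubatch_slices[0][1].start
    second_num_tokens = ubatch_slices[1][1].stop - ubatch_slices[1][1].start

    # Each ubatch is padded up to the largest per-ubatch size across DP ranks.
    num_pad_first = max_tokens_across_dp_cpu - first_num_tokens
    num_pad_second = max_tokens_across_dp_cpu - second_num_tokens

    # Rewrite the token slices in place: both ubatches get the same length,
    # and the second starts where the padded first one ends.
    ubatch_slices[0] = (ubatch_slices[0][0], slice(0, max_tokens_across_dp_cpu))
    ubatch_slices[1] = (ubatch_slices[1][0],
                        slice(max_tokens_across_dp_cpu, 2 * max_tokens_across_dp_cpu))

    # After padding, every DP rank runs the same per-ubatch token count.
    num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * dp_size,
                                            device="cpu", dtype=torch.int32)
    return num_pad_first + num_pad_second, num_tokens_after_padding

# Example: ubatches of 40 and 48 tokens, padded to a DP-wide max of 64 each.
slices = [(slice(0, 8), slice(0, 40)), (slice(8, 16), slice(40, 88))]
total_pad, after = pad_ubatch_slices(slices, max_tokens_across_dp_cpu=64, dp_size=2)
assert total_pad == 40 and slices[1][1] == slice(64, 128)

Note that the returned pad count is now the sum over both ubatches, whereas the deleted comment documented a per-ubatch count (i.e. 2x that number per DP rank).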
@@ -1513,17 +1514,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
-        num_pad_tokens, num_tokens_after_padding = \
-            self.get_dp_padding_ubatch(ubatch_slices)
-        num_scheduled_tokens += num_pad_tokens
+        num_tokens_after_padding = None
+        if ubatch_slices:
+            num_pad_tokens, num_tokens_after_padding = \
+                self.get_dp_padding_ubatch(ubatch_slices)
+            num_scheduled_tokens += num_pad_tokens
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         self.maybe_setup_kv_connector(scheduler_output)
         model_output = self._run_model(
-            attn_metadata,
-            num_scheduled_tokens,
-            ubatch_slices,
-            scheduler_output,
+            attn_metadata=attn_metadata,
+            num_scheduled_tokens=num_scheduled_tokens,
+            ubatch_slices=ubatch_slices,
+            scheduler_output=scheduler_output,
+            num_tokens_across_dp=num_tokens_after_padding
         )
         self.maybe_wait_for_kv_save()
         finished_sending, finished_recving = (
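At the call site, the DP ubatch padding is now computed only when ubatch_slices is non-empty, and the resulting per-rank tensor is threaded into the model call as num_tokens_across_dp (None on the non-ubatched path). A minimal sketch of that control flow, with the runner methods passed in as plain callables so the example stays self-contained; these are not vLLM's actual signatures:

from typing import Callable, Optional
import torch

def run_with_optional_ubatch_padding(
        num_scheduled_tokens: int,
        ubatch_slices,
        get_dp_padding_ubatch: Callable,
        run_model: Callable):
    num_tokens_after_padding: Optional[torch.Tensor] = None
    if ubatch_slices:
        # Only micro-batched (ubatch) runs need the extra DP padding pass;
        # otherwise num_tokens_across_dp stays None.
        num_pad_tokens, num_tokens_after_padding = \
            get_dp_padding_ubatch(ubatch_slices)
        num_scheduled_tokens += num_pad_tokens
    return run_model(num_scheduled_tokens=num_scheduled_tokens,
                     ubatch_slices=ubatch_slices,
                     num_tokens_across_dp=num_tokens_after_padding)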