Misc padding fixes

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-06 23:24:51 +00:00
parent a00dabcb33
commit 05ddc34913
3 changed files with 52 additions and 26 deletions

View File

@ -43,7 +43,9 @@ class DPMetadata:
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
print("STARTING AR")
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
print("finishing")
return num_tokens_tensor
@staticmethod
@ -54,7 +56,9 @@ class DPMetadata:
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
print("Starting AR")
dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
print("FINISHING AR")
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
return result
@ -80,6 +84,7 @@ class DPMetadata:
# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
# print(f"num_tokens_across_dp {num_tokens_across_dp} batchsize {batchsize}")
assert (num_tokens_across_dp is None
or num_tokens_across_dp[dp_rank] == batchsize)
if num_tokens_across_dp is None:

View File

@ -123,8 +123,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
# torch.cuda.synchronize()

View File

@ -573,7 +573,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
Optional[SpecDecodeMetadata], Optional[UBatchSlices]]:
Optional[SpecDecodeMetadata], Optional[UBatchSlices],
int, Optional[torch.Tensor]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@ -661,6 +662,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not should_ubatch and ubatch_slices:
ubatch_slices = None
num_pad_tokens = 0
num_tokens_after_padding = None
if ubatch_slices:
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
if num_pad_tokens > 0:
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@ -782,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, logits_indices, spec_decode_metadata,
ubatch_slices)
ubatch_slices, num_pad_tokens, num_tokens_after_padding)
def _compute_cascade_attn_prefix_len(
self,
@ -1261,30 +1270,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
dtype=torch.int32)
padded_first_ubatch_slice = slice(0, max_tokens_across_dp)
padded_second_ubatch_slice = slice(max_tokens_across_dp, max_tokens_across_dp * 2)
assert max_tokens_across_dp <= 2 * max_tokens_per_ubatch_local, \
f"max_tokens_across_dp: {max_tokens_across_dp} max_tokens_per_ubatch{max_tokens_per_ubatch_local}"
num_pad_tokens = (max_tokens_across_dp * 2) - \
(first_ubatch_num_tokens + second_ubatch_num_tokens)
return num_pad_tokens, num_tokens_after_padding
assert padded_first_ubatch_slice.stop - padded_first_ubatch_slice.start == \
padded_second_ubatch_slice.stop - padded_second_ubatch_slice.start
# This doesn't actually pad the ubatch slices. It just shifts the
# split point to the correct value so that padding can be applied
# to the second ubatch later. Should be called after ubatch
# slicing but before attention meta data creation
def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
num_pad_tokens: int):
original_num_tokens = ubatch_slices[1][1].stop
assert num_pad_tokens < original_num_tokens
total_num_tokens_per_ubatch = (original_num_tokens + num_pad_tokens) // 2
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, original_num_tokens)
ubatch_slices[0] = (padded_first_ubatch_slice, padded_first_ubatch_slice)
ubatch_slices[1] = (padded_first_ubatch_slice, padded_second_ubatch_slice)
# Need to assert that none of the padding is on the first ubatch
assert padded_first_ubatch_slice.stop - padded_first_ubatch_slice.start
ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
# if (num_pad_tokens_first_ubatch > 0):
# print(f"FIRST UBATCH PADDING {num_pad_tokens_first_ubatch} TOTAL: {max_tokens_across_dp_cpu} ORIGINAL{first_ubatch_num_tokens}")
# if (num_pad_tokens_second_ubatch > 0):
# print(f"SECOND UBATCH PADDING {num_pad_tokens_second_ubatch} TOTAL: {max_tokens_across_dp_cpu} ORIGINAL{second_ubatch_num_tokens}")
# print(f"num padded tokens: {num_pad_tokens} num tokens tensor: {num_tokens_after_padding} first num_tokens: {first_ubatch_num_tokens} second num tokens {second_ubatch_num_tokens}")
# This is where the second ubatch is adjusted to account for the padding.
# Should be called after attention metadata creation. This just extends
# the second ubatch slice out to the total number of tokens
# (num_tokens + padding)
def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
# TODO Add asserts to make sure stage one ran
padded_second_ubatch_slice = slice(ubatch_slices[1][1].start, num_total_tokens)
ubatch_slices[1] = (ubatch_slices[1][0], padded_second_ubatch_slice)
num_pad_tokens = (max_tokens_across_dp * 2) - \
(first_ubatch_num_tokens + second_ubatch_num_tokens)
print(f"num padded tokens: {num_pad_tokens} num tokens tensor: {num_tokens_after_padding} first num_tokens: {first_ubatch_num_tokens} second num tokens {second_ubatch_num_tokens}")
return num_pad_tokens, num_tokens_after_padding
def should_ubatch(self, should_ubatch: bool) -> bool:
dp_size = self.vllm_config.parallel_config.data_parallel_size
@ -1430,12 +1451,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results:
results.append((ubatch_ctx.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
@ -1498,7 +1519,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
# print("RUN NORMAL")
# No padding for the non ubatch case
assert not num_tokens_across_dp
assert num_tokens_across_dp is None
model_output = _run(
slice(0, num_scheduled_tokens),
set_forward_context(attn_metadata,
@ -1524,16 +1545,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output)
# num_scheduled_tokens_old = scheduler_output.total_num_scheduled_tokens
# num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_scheduled_tokens_old)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices = (
attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices, num_pad_tokens, num_tokens_after_padding = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_tokens_after_padding = None
if ubatch_slices:
num_pad_tokens, num_tokens_after_padding = \
self.get_dp_padding_ubatch(ubatch_slices)
if ubatch_slices and num_pad_tokens > 0:
num_scheduled_tokens += num_pad_tokens
self.pad_out_ubatch_second_stage(ubatch_slices, num_scheduled_tokens)
else:
num_tokens_after_padding = None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
self.maybe_setup_kv_connector(scheduler_output)