From 60499f63afeced1de321aebca887112259c25d4f Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Sat, 7 Jun 2025 16:16:26 +0000 Subject: [PATCH] padding is getting correctness but there are still some edgecases tripping asserts Signed-off-by: Sage Moore --- vllm/forward_context.py | 7 +++---- vllm/v1/worker/gpu_model_runner.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 0e1575d71647b..e422b31de00fb 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -43,9 +43,9 @@ class DPMetadata: device="cpu", dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group - print("STARTING AR") + # print("STARTING AR num_tokens_across_dp") dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - print("finishing") + # print("finishing num_tokens_across_dp") return num_tokens_tensor @staticmethod @@ -56,10 +56,9 @@ class DPMetadata: device="cpu", dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group - print("Starting AR") dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group) - print("FINISHING AR") result: bool = bool(torch.all(should_ubatch_tensor == 1).item()) + # print(f"FINISHING AR should_ubatch_across_dp {result} {should_ubatch_tensor}") return result @staticmethod diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fe020bd66f88d..fd34451d815d5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -658,6 +658,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output) should_ubatch = self.should_ubatch(True if ubatch_slices else False) + if should_ubatch: + assert ubatch_slices # Don't attempt to microbatch unless every other DP worker is also microbatching if not should_ubatch and ubatch_slices: ubatch_slices = None @@ -665,6 +667,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_pad_tokens = 0 num_tokens_after_padding = None if ubatch_slices: + assert should_ubatch num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices) if num_pad_tokens > 0: self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens) @@ -1425,7 +1428,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple: if use_dummy_input: - print("MAKING DUMMY BATCH") + # print("MAKING DUMMY BATCH") # assert num_dummy_tokens == 1 return self._get_dummy_model_inputs(num_dummy_tokens) else: @@ -1451,12 +1454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): @torch.inference_mode() def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input): - print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True) + # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True) model_output = _run(token_slice, ubatch_ctx, use_dummy_input) if save_results: results.append((ubatch_ctx.id, model_output)) - print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True) + # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True) def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run, num_tokens_across_dp) -> torch.Tensor: @@ -1479,12 +1482,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_tokens = num_dummy_tokens if is_dummy_ubatch or \ is_dummy_run else (tokens_slice.stop - tokens_slice.start) + # if num_tokens_across_dp is None: + # print(f"GOING TO CALL AR: {i}") ubatch_ctxs[i].forward_context = create_forward_context( attn_metadata[i] if attn_metadata is not None else None, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp if i == 1 else None) + num_tokens_across_dp=num_tokens_across_dp) thread = threading.Thread(target=_ubatch_thread, args=( @@ -1554,8 +1559,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): if ubatch_slices and num_pad_tokens > 0: num_scheduled_tokens += num_pad_tokens self.pad_out_ubatch_second_stage(ubatch_slices, num_scheduled_tokens) - else: - num_tokens_after_padding = None # Run the decoder. # Use persistent buffers for CUDA graphs.