From 82ae694de6d9c3d845b884fb054437b0e7ea51d5 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 3 Jul 2025 20:47:39 +0000 Subject: [PATCH] comments cleanup etc Signed-off-by: Sage Moore --- vllm/v1/worker/gpu_model_runner.py | 15 ++++++++------- vllm/v1/worker/ubatching.py | 23 +++++++---------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 160a276426ff4..51a964817505a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -593,6 +593,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): total_num_scheduled_tokens >= \ self.parallel_config.microbatching_token_threshold \ and max_num_scheduled_tokens == 1 + # Don't microbatch unless every other DP worker is also microbatching should_ubatch = self.should_ubatch(should_attempt_ubatching) if not should_ubatch: @@ -611,20 +612,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): ] num_pad_tokens = 0 num_tokens_after_padding = None - ubatch_bailout = False + ubatch_abort = False num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices) if num_pad_tokens > 0: + # Check if the padding would result in an empty second ubatch. + # If so abort ubatching if num_pad_tokens < scheduler_output.total_num_scheduled_tokens: self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens) else: - # We bail out of ubatching here. This accounts for the case where - # the padding would result in an "empty" second ubatch. - # TODO: just make the second ubatch a dummy ubatch - ubatch_bailout = True + ubatch_abort = True # Note that if we are attempting to ubatch by this point then we know that no - # DP ranks are doing dummy runs - should_ubatch = self.should_ubatch(False if ubatch_bailout else True) + # DP ranks are doing dummy runs. 
This means we don't need a second call to + # should_ubatch in _dummy_run. + should_ubatch = self.should_ubatch(False if ubatch_abort else True) if not should_ubatch: return (None, 0, None) return (ubatch_slices, num_pad_tokens, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 160fbe75411c5..b16854c9375a5 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -65,26 +65,16 @@ class UBatchContext: self.current_stream = stream torch.cuda.set_stream(self.current_stream) - def ctx_valid_state(self): - assert forward_context._forward_context == self.forward_context - assert current_stream() == self.current_stream - assert not self.cpu_wait_event.is_set() - pass - def _signal_comm_done(self): - self.ctx_valid_state() self.gpu_comm_done_event.record(self.comm_stream) def _signal_compute_done(self): - self.ctx_valid_state() self.gpu_compute_done_event.record(self.compute_stream) def _wait_compute_done(self): - self.ctx_valid_state() self.comm_stream.wait_event(self.gpu_compute_done_event) def _wait_comm_done(self): - self.ctx_valid_state() self.compute_stream.wait_event(self.gpu_comm_done_event) def stream_string(self): @@ -96,29 +86,30 @@ class UBatchContext: return "COMM" def _cpu_yield(self): - self.ctx_valid_state() + # It is critical for correctness that only one thread is running + # at a time. 
These asserts verify that this is the only + # thread running before it wakes the other thread and goes to sleep. + assert forward_context._forward_context == self.forward_context + assert current_stream() == self.current_stream + assert not self.cpu_wait_event.is_set() + self.cpu_signal_event.set() self.cpu_wait_event.wait() self.cpu_wait_event.clear() self._restore_context() - self.ctx_valid_state() def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream - self.ctx_valid_state() self._signal_compute_done() self._cpu_yield() - self.ctx_valid_state() assert self.current_stream == self.compute_stream self.update_stream(self.comm_stream) self._wait_compute_done() def yield_and_switch_from_comm_to_compute(self): assert current_stream() == self.comm_stream - self.ctx_valid_state() self._signal_comm_done() self._cpu_yield() - self.ctx_valid_state() assert self.current_stream == self.comm_stream self.update_stream(self.compute_stream) self._wait_comm_done()