diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 15afe390aa51f..1b9fcb63a7804 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1300,12 +1300,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): input_ids, positions, inputs_embeds, intermediate_tensors = \ model_inputs(token_slice, use_dummy_input) with context: + if isinstance(context, UBatchContext): + print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape}") model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if isinstance(context, UBatchContext): + print(f"Ran ubatch {context.id}putput {model_output.shape}") if isinstance(context, UBatchContext): # Clone before we leave the ubatch context model_output = model_output.clone() diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 3688ae8129db6..aab9c2a8d68a7 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -22,6 +22,7 @@ class UBatchContext: cpu_signal_event: threading.Event, gpu_wait_event: torch.cuda.Event, gpu_signal_event: torch.cuda.Event, + gpu_wait_on_launch: bool = False, schedule="default"): self.id = id self.stream = stream @@ -33,18 +34,27 @@ class UBatchContext: self.gpu_signal_event = gpu_signal_event self.schedule = schedule self.done_event = torch.cuda.Event() + self.gpu_wait_on_launch = gpu_wait_on_launch def __enter__(self): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self self._cpu_wait() - self.gpu_stream_wait() + start_event = torch.cuda.Event() + self.original_stream.record_event(start_event) + self.stream.wait_event(start_event) + print("Starting ubatch %d" % self.id) + if self.gpu_wait_on_launch: + self.gpu_stream_wait() return self def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = None torch.cuda.set_stream(self.original_stream) + print("Finishing ubatch %d" % self.id) + self._signal() + self._signal() self._signal() return False @@ -72,9 +82,11 @@ class UBatchContext: def gpu_stream_wait(self): self.stream.wait_event(self.gpu_wait_event) - def yield_(self, gpu_wait: bool = True): + def _yield(self, gpu_wait: bool = True): + print("Yielding ubatch %d" % self.id) self._signal() self._cpu_wait() + print("Resuming ubatch %d" % self.id) if gpu_wait: self.gpu_stream_wait() @@ -91,11 +103,19 @@ class UBatchContext: _CURRENT_CONTEXT: dict = {} -def yield_impl(schedule="default"): +def get_current_ubatch_context() -> Optional[UBatchContext]: + global _CURRENT_CONTEXT + """ + Get the current UBatchContext for the current thread. + """ + return _CURRENT_CONTEXT.get(threading.get_ident(), None) + +def yield_impl(schedule="default", gpu_wait: bool = True): # Perform the barrier if a context exists for this thread - ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None) - if ctx is not None and ctx.schedule == schedule: - ctx.yield_() + ctx = get_current_ubatch_context() + print("you are in yield_impl", ctx) + if ctx is not None: + ctx._yield(gpu_wait=gpu_wait) # 2) Register kernel for CUDA, mark as mutating to prevent the compiler from @@ -137,6 +157,7 @@ def make_ubatch_context_chain( cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], gpu_wait_event=gpu_events[i], gpu_signal_event=gpu_events[(i + 1) % num_micro_batches], + gpu_wait_on_launch=(i > 0), ) ctxs.append(ctx)