diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 35eef966e2717..c0e89ff6c40dc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -336,21 +336,10 @@ class FusedMoEModularKernel(torch.nn.Module): device=a1.device, dtype=workspace_dtype) - # if (ubatch_ctx := get_current_ubatch_context()) is not None: - # print("in modular moe, ubatch:", ubatch_ctx.id) - a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare( a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, expert_map, apply_router_weight_on_input) - # if (ubatch_ctx := get_current_ubatch_context()) is not None: - # print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts) - - print("pre synchronize") - dump_ubatching_state() - torch.cuda.synchronize(a1.device) - print("post synchronize") - fused_out = self.fused_experts.apply( a1q, w1, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5c456eb409e6c..ad88215cd6607 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1318,9 +1318,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): return model_output @torch.inference_mode() - def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt): - ubatch_ctx.stream.wait_stream(root_stream) - + def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input): model_output = _run(token_slice, ubatch_ctx, use_dummy_input) if save_results: @@ -1330,24 +1328,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): results = [] assert len(ubatch_slices) == 2, "Only two ubatches has been tested" root_stream = current_stream() - - if not hasattr(self, "ubatch_streams"): - # Create the ubatch streams - self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))] - - - # We have to be careful creating the forward contexts here otherwise we can end - # up with the dummy contexts have num_tokens set to 0 - # ubatch_fwd_ctxs = [create_forward_context( - # attn_metadata[i] if attn_metadata is not None else None, - # self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start) - # ) for i, (_, tokens_slice) in enumerate(ubatch_slices)] - ubatch_ctxs, start_hook = make_ubatch_contexts( + + ubatch_ctxs = make_ubatch_contexts( len(ubatch_slices), compute_stream=root_stream, device=self.device) - setup_done = threading.Event() - ubatch_threads = [] # Ubatches will manually manage the forward context, so we override # it to None here so we can have it restored correctly later @@ -1366,20 +1351,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): thread = threading.Thread(target=_ubatch_thread, args=( ubatch_ctxs[i], - root_stream, tokens_slice, - attn_metadata[i] if attn_metadata is not None else None, results, not is_dummy_ubatch or is_dummy_run, is_dummy_ubatch or is_dummy_run, - setup_done, )) ubatch_threads.append(thread) thread.start() - # Single the first ubatch to start - start_hook(root_stream) - for thread in ubatch_threads: thread.join() diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 1907e0509a91f..de2efbb5cc27b 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -39,6 +39,7 @@ class UBatchContext: def __enter__(self): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self + self._restore_context() # Assume we start on the compute stream assert current_stream() == self.compute_stream, \ "Expected to start on the compute stream, but found %s" % current_stream() @@ -52,9 +53,6 @@ class UBatchContext: return False def _restore_context(self): - # When we resume i.e. switch back to this micro-batch, we make sure - # we have the correct stream and forward context - torch.cuda.set_stream(self.stream) forward_context._forward_context = self.forward_context def _signal_comm_done(self): @@ -130,20 +128,21 @@ def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="defaul pass def dump_ubatching_state(): - """ - Dump the current UBatchContext state for debugging. - """ + pass + # """ + # Dump the current UBatchContext state for debugging. + # """ - dp_rank = os.getenv("VLLM_DP_RANK", None) + # dp_rank = os.getenv("VLLM_DP_RANK", None) - for ctx in _CURRENT_CONTEXT.values(): - print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n" - f" Stream: {ctx.stream}, ({ctx.stream.query()})\n" - f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n" - f" CPU Wait Event: {ctx.cpu_wait_event}\n" - f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n" - f" CPU Signal Event: {ctx.cpu_signal_event}\n" - f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n") + # for ctx in _CURRENT_CONTEXT.values(): + # print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n" + # f" Stream: {ctx.stream}, ({ctx.stream.query()})\n" + # f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n" + # f" CPU Wait Event: {ctx.cpu_wait_event}\n" + # f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n" + # f" CPU Signal Event: {ctx.cpu_signal_event}\n" + # f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n") """ """ @@ -181,4 +180,4 @@ def make_ubatch_contexts( ) ctxs.append(ctx) - return ctxs, \ No newline at end of file + return ctxs \ No newline at end of file