diff --git a/vllm/config.py b/vllm/config.py index f6e40d8b8edcf..a74821f28140e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1937,6 +1937,9 @@ class ParallelConfig: request is greater than this threshold, microbatching will be used. Otherwise, the request will be processed in a single batch.""" + enable_async_comms: bool = False + """enable async comms""" + ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 51f8e433b1b94..7dc3de62243cc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -304,6 +304,7 @@ class EngineArgs: data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_microbatching: bool = ParallelConfig.enable_microbatching + enable_async_comms: bool = ParallelConfig.enable_async_comms enable_eplb: bool = ParallelConfig.enable_eplb num_redundant_experts: int = ParallelConfig.num_redundant_experts eplb_window_size: int = ParallelConfig.eplb_window_size @@ -640,6 +641,8 @@ class EngineArgs: **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-microbatching", **parallel_kwargs["enable_microbatching"]) + parallel_group.add_argument("--enable-async-comms", + **parallel_kwargs["enable_async_comms"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--num-redundant-experts", @@ -1166,6 +1169,7 @@ class EngineArgs: data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_microbatching=self.enable_microbatching, + enable_async_comms=self.enable_async_comms, enable_eplb=self.enable_eplb, num_redundant_experts=self.num_redundant_experts, eplb_window_size=self.eplb_window_size, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index be7913a9cab96..30a0b5801f71f 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -127,6 +127,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): hidden_size = a1.size(1) ubatch_ctx = get_current_ubatch_context() a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0 + do_recv_hook = True if ubatch_ctx is not None and \ + ubatch_ctx.enable_async_comms else False if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -147,16 +149,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # Dispatch yield_and_switch_from_compute_to_comm(schedule="default") - expert_x, expert_num_tokens, handle, _, _= \ + expert_x, expert_num_tokens, handle, _, recv_hook= \ self.buffers[a2a_idx].low_latency_dispatch(a1, topk_ids, self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, - return_recv_hook=False) + return_recv_hook=do_recv_hook) self.handles[a2a_idx] = handle - yield_and_switch_from_comm_to_compute(schedule="default") + yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook) expert_x, expert_x_scale = self._do_quant( expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, @@ -178,6 +180,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ubatch_ctx = get_current_ubatch_context() a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0 handle = self.handles[a2a_idx] + do_recv_hook = True if ubatch_ctx is not None and \ + ubatch_ctx.enable_async_comms else False assert handle is not None combine_topk_weights = topk_weights @@ -187,12 +191,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # TODO (varun) : Enable zero copy mode yield_and_switch_from_compute_to_comm(schedule="default") - _ = self.buffers[a2a_idx].low_latency_combine(fused_expert_output, + _, _, recv_hook = self.buffers[a2a_idx].low_latency_combine(fused_expert_output, topk_ids, combine_topk_weights, handle, async_finish=False, zero_copy=False, - return_recv_hook=False, + return_recv_hook=do_recv_hook, out=output) - yield_and_switch_from_comm_to_compute(schedule="default") + yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f96354ac578ca..28f26f93c50ac 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1656,7 +1656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): comm_stream=self.comm_stream, compute_stream=compute_stream, forward_contexts=forward_contexts, - device=self.device) + device=self.device, + enable_async_comms=self.parallel_config.enable_async_comms) ubatch_metadata: list[UbatchMetadata] = [] for i, (_, tokens_slice) in enumerate(ubatch_slices): diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 433cf28bef1c4..f6173d353a9e1 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -24,6 +24,7 @@ class UBatchContext: cpu_signal_event: threading.Event, gpu_comm_done_event: torch.cuda.Event, gpu_compute_done_event: torch.cuda.Event, + enable_async_comms: bool, schedule: str = "default"): self.id = id self.comm_stream = comm_stream @@ -34,6 +35,7 @@ class UBatchContext: self.current_stream = compute_stream self.gpu_comm_done_event = gpu_comm_done_event self.gpu_compute_done_event = gpu_compute_done_event + self.enable_async_comms = enable_async_comms self.schedule = schedule def __enter__(self): @@ -106,10 +108,14 @@ class UBatchContext: self.update_stream(self.comm_stream) self._wait_compute_done() - def yield_and_switch_from_comm_to_compute(self): + def yield_and_switch_from_comm_to_compute(self, recv_hook = None): assert current_stream() == self.comm_stream - self._signal_comm_done() + if recv_hook is None: + self._signal_comm_done() self._cpu_yield() + if recv_hook is not None: + recv_hook() + self._signal_comm_done() assert self.current_stream == self.comm_stream self.update_stream(self.compute_stream) self._wait_comm_done() @@ -133,11 +139,11 @@ def yield_and_switch_from_compute_to_comm(schedule="default"): ctx.yield_and_switch_from_compute_to_comm() -def yield_and_switch_from_comm_to_compute(schedule="default"): +def yield_and_switch_from_comm_to_compute(schedule="default", recv_hook = None): # Perform the barrier if a context exists for this thread ctx = get_current_ubatch_context() if ctx is not None and ctx.schedule == schedule: - ctx.yield_and_switch_from_comm_to_compute() + ctx.yield_and_switch_from_comm_to_compute(recv_hook=recv_hook) def make_ubatch_contexts( @@ -146,6 +152,7 @@ def make_ubatch_contexts( comm_stream: torch.cuda.Stream, forward_contexts: list[ForwardContext], device: Optional[torch.device] = None, + enable_async_comms: bool = False, schedule: str = "default", ) -> list[UBatchContext]: assert num_micro_batches == 2, "only been tested with 2 micro-batches" @@ -175,6 +182,7 @@ def make_ubatch_contexts( num_micro_batches], gpu_comm_done_event=gpu_comm_done_events[i], gpu_compute_done_event=gpu_compute_done_events[i], + enable_async_comms=enable_async_comms, schedule=schedule) ctxs.append(ctx)