diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cd7b1a2c1cfa6..8461e8e1a8427 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -307,13 +307,23 @@ class AllToAllCache: return instance -# Global singleton -_all_to_all_cache = AllToAllCache() +from typing import List +_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()] -# Factory function as a cleaner interface -def get_all_to_all(**kwargs): - return _all_to_all_cache.get_or_create(**kwargs) +# Factory function with cache ID support +def get_all_to_all(cache_id: int, **kwargs): + """Get or create an AllToAll instance from the specified cache. + + Args: + cache_id: Integer ID of the cache to use (0 or 1) + **kwargs: Arguments passed to AllToAll creation + + Returns: + AllToAll instance from the specified cache + """ + assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}" + return _all_to_all_cache[cache_id].get_or_create(**kwargs) @CustomOp.register("unquantized_fused_moe") @@ -692,25 +702,26 @@ def _construct_prepare_finalize( if moe.use_pplx_kernels: logger.debug("using PplxPrepareAndFinalize") - - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else + + kwargs = { + "max_num_tokens" :max_num_tokens, + "num_experts" :moe.num_experts, + "experts_per_token" :moe.experts_per_token, # topk + "rank" :rank, + "world_size" :world_size, + "dp_size" :dp_size, + "hidden_dim":moe.hidden_dim, + "hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize, + "hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else ((moe.hidden_dim + moe.block_size - 1) // - moe.block_size * torch.float32.itemsize))) + moe.block_size * torch.float32.itemsize)), + } + + + a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)] return PplxPrepareAndFinalize( - all_to_all, + a2as, max_num_tokens=max_num_tokens, world_size=world_size, rank=rank, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 07f5b2bfba998..87d2745c20eb4 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -19,7 +19,7 @@ from vllm.v1.worker.ubatching import ( class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__(self, - a2a: pplx.AllToAll, + a2as: list[pplx.AllToAll], max_num_tokens: int, world_size: int, rank: int, @@ -28,7 +28,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): block_shape: Optional[list[int]] = None): super().__init__() assert max_num_tokens > 0 - self.a2a = a2a + self.a2as = a2as self.block_shape = block_shape self.max_num_tokens = max_num_tokens self.world_size = world_size @@ -49,6 +49,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id assert rank_topk_ids.size(0) == num_tokens # assert expert_map is None, "NYI" @@ -110,7 +113,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): bound_m: Optional[torch.Tensor] = None def dispatch(send: bool): - self.a2a.dispatch( + self.a2as[a2a_idx].dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, out_expert_x_scale=expert_x_scale, @@ -122,9 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_recv=not send, ) + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 yield_and_switch_from_compute_to_comm_impl(schedule="default") dispatch(True) # Send + torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER SEND SYNC", flush=True) dispatch(False) # Recv + # torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER RECV SYNC", flush=True) yield_and_switch_from_comm_to_compute_impl(schedule="default") return expert_x, expert_x_scale, expert_num_tokens @@ -141,6 +150,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id assert topk_ids.size(0) == num_tokens, ( f"{topk_ids.size(0)} == {num_tokens}") @@ -153,7 +165,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_weights = torch.ones_like(topk_weights) def combine(send: bool): - self.a2a.combine( + self.a2as[a2a_idx].combine( out_tokens=output, indices=topk_ids, weights=topk_weights, @@ -162,8 +174,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_send=send, do_recv=not send, ) - yield_and_switch_from_compute_to_comm_impl(schedule="default") combine(True) + torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) combine(False) + # torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) yield_and_switch_from_comm_to_compute_impl(schedule="default") \ No newline at end of file diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 5bc8a8e31def7..e17d57ee6aeeb 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,6 +7,7 @@ import os from typing import Optional from torch.library import Library from torch.library import custom_op, register_kernel +from vllm.distributed import (get_dp_group) from vllm.utils import current_stream from vllm import forward_context @@ -53,7 +54,7 @@ class UBatchContext: def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = None - print("Finishing ubatch %d\n" % self.id) + print("Finishing ubatch %d\n" % self.id, flush=True) self.cpu_signal_event.set() self.cpu_wait_event.clear() self.current_stream = self.compute_stream @@ -81,49 +82,63 @@ class UBatchContext: self.gpu_compute_done_event.record(self.compute_stream) def _wait_compute_done(self): - print("Waiting on compute stream") + # print(f"{self.id} Waiting on COMPUTE stream", flush=True) self.ctx_valid_state() self.comm_stream.wait_event(self.gpu_compute_done_event) - print("Compute stream done") + # print("Compute stream done", flush=True) def _wait_comm_done(self): - print("Waiting on comm stream") + # print(f"{self.id} Waiting on COMM stream", flush=True) self.ctx_valid_state() self.compute_stream.wait_event(self.gpu_comm_done_event) - print("Comm stream done") + # print("Comm stream done", flush=True) + + def stream_string(self): + if current_stream() == self.compute_stream: + assert self.current_stream == self.compute_stream + return "COMPUTE" + elif current_stream() == self.comm_stream: + assert self.current_stream == self.comm_stream + return "COMM" def _cpu_yield(self): - print("UBatchContext: %d yielding CPU\n" % self.id) + # print(f"UBatchContext: {self.id} yielding CPU", flush=True) self.ctx_valid_state() self.cpu_signal_event.set() self.cpu_wait_event.wait() self.cpu_wait_event.clear() self._restore_context() self.ctx_valid_state() - print("UBatchContext: %d resuming CPU\n" % self.id) + # print(f"UBatchContext: {self.id} resuming CPU", flush=True) def yield_and_switch_from_compute_to_comm(self): - print("Yield and switch from compute") + assert current_stream() == self.compute_stream + dp_rank = get_dp_group().rank_in_group + print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True) self.ctx_valid_state() - self._signal_compute_done() + # self._signal_compute_done() self._cpu_yield() self.ctx_valid_state() assert self.current_stream == self.compute_stream self.update_stream(self.comm_stream) - self._wait_compute_done() + print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True) + # self._wait_compute_done() def yield_and_switch_from_comm_to_compute(self): + assert current_stream() == self.comm_stream + dp_rank = get_dp_group().rank_in_group + print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True) self.ctx_valid_state() - self._signal_comm_done() + # self._signal_comm_done() self._cpu_yield() self.ctx_valid_state() assert self.current_stream == self.comm_stream self.update_stream(self.compute_stream) - self._wait_comm_done() + print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True) + # self._wait_comm_done() _CURRENT_CONTEXT: dict = {} - def get_current_ubatch_context() -> Optional[UBatchContext]: global _CURRENT_CONTEXT """