From e080e068ed762ba1d030a6e1177aff3f9a7c65d5 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 3 Jun 2025 18:21:17 +0000 Subject: [PATCH] fix pplx a2a Signed-off-by: Sage Moore --- .../device_communicators/all2all.py | 21 +++++++-- vllm/model_executor/layers/fused_moe/layer.py | 4 +- .../layers/fused_moe/pplx_prepare_finalize.py | 10 ++-- vllm/v1/worker/gpu_model_runner.py | 2 + vllm/v1/worker/ubatching.py | 47 ++++++++++--------- 5 files changed, 48 insertions(+), 36 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index a250ec89cd5ba..e6caf6b42fc06 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -100,13 +100,23 @@ class PPLXAll2AllManager(All2AllManagerBase): logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) - self.handle_cache = Cache() + # self.handle_cache = Cache() + self.handle_caches = [Cache(), Cache()] def get_handle(self, kwargs): import pplx_kernels as pplx - return self.handle_cache.get_or_create( + return self.handle_caches[0].get_or_create( kwargs, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode) + + def get_handles(self, kwargs): + import pplx_kernels as pplx + first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + return [first_handle, second_handle] + def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -116,9 +126,10 @@ class PPLXAll2AllManager(All2AllManagerBase): raise NotImplementedError def destroy(self): - with self.handle_cache._lock: - for _, handle in self.handle_cache._cache.items(): - handle.destroy() + for handle_cache in self.handle_caches: + with handle_cache._lock: + for _, handle in handle_cache._cache.items(): + handle.destroy() if self.internode: from pplx_kernels.nvshmem import nvshmem_finalize diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index af7b98e14c6c8..c70eb79595714 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -272,10 +272,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): group_name=all2all_manager.cpu_group.group_name, ) - handle = all2all_manager.get_handle(all_to_all_args) + handles = all2all_manager.get_handles(all_to_all_args) prepare_finalize = PplxPrepareAndFinalize( - handle, + handles, max_num_tokens=moe.max_num_tokens, world_size=all2all_manager.world_size, rank=all2all_manager.rank, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5243bbfb052fa..d36ba24188936 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -45,7 +45,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - assert False num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K ubatch_ctx = get_current_ubatch_context() @@ -128,10 +127,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 yield_and_switch_from_compute_to_comm_impl(schedule="default") dispatch(True) # Send - # torch.cuda.synchronize() + torch.cuda.synchronize() # print(f"{ubatch_id} AFTER SEND SYNC", flush=True) dispatch(False) # Recv - # torch.cuda.synchronize() + torch.cuda.synchronize() # print(f"{ubatch_id} AFTER RECV SYNC", flush=True) yield_and_switch_from_comm_to_compute_impl(schedule="default") torch.cuda.synchronize() @@ -145,7 +144,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: - assert False num_tokens = output.size(0) # M # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) @@ -177,9 +175,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): yield_and_switch_from_compute_to_comm_impl(schedule="default") combine(True) - # torch.cuda.synchronize() + torch.cuda.synchronize() # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) combine(False) # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) yield_and_switch_from_comm_to_compute_impl(schedule="default") - # torch.cuda.synchronize() + torch.cuda.synchronize() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0a2344edfb30d..736c3747084c7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1363,10 +1363,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): @torch.inference_mode() def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input): + print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True) model_output = _run(token_slice, ubatch_ctx, use_dummy_input) if save_results: results.append(model_output) + print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True) def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run) -> torch.Tensor: diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 9a8b546819e43..7113541a8fb90 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -8,6 +8,7 @@ from torch.library import custom_op from vllm import forward_context from vllm.utils import current_stream +from vllm.distributed.parallel_state import get_dp_group class UBatchContext: @@ -69,28 +70,28 @@ class UBatchContext: torch.cuda.set_stream(self.current_stream) def ctx_valid_state(self): - # assert forward_context._forward_context == self.forward_context - # assert current_stream() == self.current_stream - # assert not self.cpu_wait_event.is_set() + assert forward_context._forward_context == self.forward_context + assert current_stream() == self.current_stream + assert not self.cpu_wait_event.is_set() pass def _signal_comm_done(self): - # self.ctx_valid_state() + self.ctx_valid_state() self.gpu_comm_done_event.record(self.comm_stream) def _signal_compute_done(self): - # self.ctx_valid_state() + self.ctx_valid_state() self.gpu_compute_done_event.record(self.compute_stream) def _wait_compute_done(self): # print(f"{self.id} Waiting on COMPUTE stream", flush=True) - # self.ctx_valid_state() + self.ctx_valid_state() self.comm_stream.wait_event(self.gpu_compute_done_event) # print("Compute stream done", flush=True) def _wait_comm_done(self): # print(f"{self.id} Waiting on COMM stream", flush=True) - # self.ctx_valid_state() + self.ctx_valid_state() self.compute_stream.wait_event(self.gpu_comm_done_event) # print("Comm stream done", flush=True) @@ -104,42 +105,42 @@ class UBatchContext: def _cpu_yield(self): # print(f"UBatchContext: {self.id} yielding CPU", flush=True) - # self.ctx_valid_state() + self.ctx_valid_state() self.cpu_signal_event.set() self.cpu_wait_event.wait() self.cpu_wait_event.clear() self._restore_context() - # self.ctx_valid_state() + self.ctx_valid_state() # print(f"UBatchContext: {self.id} resuming CPU", flush=True) def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream - # dp_rank = get_dp_group().rank_in_group - # print(f"DP: {dp_rank} UB: {self.id} " - # f"Yield and switch from {self.stream_string()}", flush=True) - # self.ctx_valid_state() + dp_rank = get_dp_group().rank_in_group + print(f"DP: {dp_rank} UB: {self.id} " + f"Yield and switch from {self.stream_string()}", flush=True) + self.ctx_valid_state() self._signal_compute_done() self._cpu_yield() - # self.ctx_valid_state() + self.ctx_valid_state() assert self.current_stream == self.compute_stream self.update_stream(self.comm_stream) - # print(f"DP: {dp_rank} UB: {self.id} " - # f"Resuming on stream {self.stream_string()}", flush=True) + print(f"DP: {dp_rank} UB: {self.id} " + f"Resuming on stream {self.stream_string()}", flush=True) self._wait_compute_done() def yield_and_switch_from_comm_to_compute(self): assert current_stream() == self.comm_stream - # dp_rank = get_dp_group().rank_in_group - # print(f"DP: {dp_rank} UB: {self.id} " - # f"Yield and switch from {self.stream_string()}", flush=True) - # self.ctx_valid_state() + dp_rank = get_dp_group().rank_in_group + print(f"DP: {dp_rank} UB: {self.id} " + f"Yield and switch from {self.stream_string()}", flush=True) + self.ctx_valid_state() self._signal_comm_done() self._cpu_yield() - # self.ctx_valid_state() + self.ctx_valid_state() assert self.current_stream == self.comm_stream self.update_stream(self.compute_stream) - # print(f"DP: {dp_rank} UB: {self.id} " - # f"Resuming on stream {self.stream_string()}", flush=True) + print(f"DP: {dp_rank} UB: {self.id} " + f"Resuming on stream {self.stream_string()}", flush=True) self._wait_comm_done()