From ef3c01c9756ad2526d6d2f3ba37df79848738f6b Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Mon, 9 Jun 2025 21:03:28 +0000
Subject: [PATCH] fix using the same buffer across ubatches

Signed-off-by: Lucas Wilkinson
---
 vllm/model_executor/layers/fused_moe/layer.py | 22 ++++++++++++-------
 .../layers/fused_moe/pplx_prepare_finalize.py |  8 +++----
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 9912eea566f93..485988fefe679 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -31,6 +31,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import direct_register_custom_op
+from vllm.v1.worker.ubatching import get_current_ubatch_context
+
 has_pplx = importlib.util.find_spec("pplx_kernels") is not None
 has_deepep = importlib.util.find_spec("deep_ep") is not None
@@ -952,12 +954,12 @@ class FusedMoE(torch.nn.Module):
                 or self.moe_parallel_config.use_deepep_ll_kernels):
             act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.hidden_size),
+                (2, MOE_DP_CHUNK_SIZE, self.hidden_size),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())

             self.batched_router_logits = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
+                (2, MOE_DP_CHUNK_SIZE, self.global_num_experts),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())

@@ -1376,15 +1378,19 @@ class FusedMoE(torch.nn.Module):
             chunk_size = chunk_end - chunk_start
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
+
+            ubatch_ctx = get_current_ubatch_context()
+            ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
+            batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
+            batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
+            batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]

-            assert (self.batched_hidden_states.size(0)  # type: ignore
+            assert (batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
-            assert (self.batched_router_logits.size(0)  # type: ignore
+            assert (batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
-            staged_hidden_states = self.batched_hidden_states[:
-                                                              chunk_size, :]  # type: ignore
-            staged_router_logits = self.batched_router_logits[:
-                                                              chunk_size, :]  # type: ignore
+            staged_hidden_states = batched_hidden_states[:chunk_size, :]  # type: ignore
+            staged_router_logits = batched_router_logits[:chunk_size, :]  # type: ignore
             staged_hidden_states.copy_(hidden_states, non_blocking=True)
             staged_router_logits.copy_(router_logits, non_blocking=True)

diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 8583681afa802..e795c8545773a 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -134,14 +134,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             do_recv=not send,
         )

-        # yield_and_switch_from_compute_to_comm_impl(schedule="default")
+        yield_and_switch_from_compute_to_comm_impl(schedule="default")
         dispatch(True)  # Send
         # torch.cuda.synchronize()
         # print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
         dispatch(False)  # Recv
         # torch.cuda.synchronize()
         # print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
-        # yield_and_switch_from_comm_to_compute_impl(schedule="default")
+        yield_and_switch_from_comm_to_compute_impl(schedule="default")
         # torch.cuda.synchronize()
         if expert_x_scale is not None:
             expert_x_scale = expert_x_scale[:, :, 0:1]
@@ -185,11 +185,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             do_recv=not send,
         )

-        # yield_and_switch_from_compute_to_comm_impl(schedule="default")
+        yield_and_switch_from_compute_to_comm_impl(schedule="default")
         combine(True)
         # torch.cuda.synchronize()
         # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
         combine(False)
         # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
-        # yield_and_switch_from_comm_to_compute_impl(schedule="default")
+        yield_and_switch_from_comm_to_compute_impl(schedule="default")
         torch.cuda.synchronize()