fix using the same buffer across ubatches

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-06-09 21:03:28 +00:00
parent 642bf2dd8b
commit ef3c01c975
2 changed files with 18 additions and 12 deletions

View File

@ -31,6 +31,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.worker.ubatching import get_current_ubatch_context
has_pplx = importlib.util.find_spec("pplx_kernels") is not None has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None has_deepep = importlib.util.find_spec("deep_ep") is not None
@ -952,12 +954,12 @@ class FusedMoE(torch.nn.Module):
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size), (2, MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype, dtype=act_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros( self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts), (2, MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype, dtype=act_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
@ -1376,15 +1378,19 @@ class FusedMoE(torch.nn.Module):
chunk_size = chunk_end - chunk_start chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :] hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :]
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
assert (self.batched_hidden_states.size(0) # type: ignore assert (batched_hidden_states.size(0) # type: ignore
>= chunk_size) >= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore assert (batched_router_logits.size(0) # type: ignore
>= chunk_size) >= chunk_size)
staged_hidden_states = self.batched_hidden_states[: staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
chunk_size, :] # type: ignore staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)

View File

@ -134,14 +134,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send, do_recv=not send,
) )
# yield_and_switch_from_compute_to_comm_impl(schedule="default") yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send dispatch(True) # Send
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv dispatch(False) # Recv
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True) # print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
# yield_and_switch_from_comm_to_compute_impl(schedule="default") yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize() # torch.cuda.synchronize()
if expert_x_scale is not None: if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1] expert_x_scale = expert_x_scale[:, :, 0:1]
@ -185,11 +185,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send, do_recv=not send,
) )
# yield_and_switch_from_compute_to_comm_impl(schedule="default") yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True) combine(True)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False) combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
# yield_and_switch_from_comm_to_compute_impl(schedule="default") yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize() torch.cuda.synchronize()