fix using the same buffer across ubatches

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-06-09 21:03:28 +00:00
parent 642bf2dd8b
commit ef3c01c975
2 changed files with 18 additions and 12 deletions

View File

@ -31,6 +31,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
from vllm.v1.worker.ubatching import get_current_ubatch_context
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None
@ -952,12 +954,12 @@ class FusedMoE(torch.nn.Module):
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
(2, MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
(2, MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
device=torch.cuda.current_device())
@ -1376,15 +1378,19 @@ class FusedMoE(torch.nn.Module):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
assert (self.batched_hidden_states.size(0) # type: ignore
assert (batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
assert (batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)

View File

@ -134,14 +134,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
@ -185,11 +185,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize()