mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 06:37:03 +08:00
Fix ubatches sharing the same staging buffer (use a separate buffer per ubatch)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
642bf2dd8b
commit
ef3c01c975
@ -31,6 +31,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.worker.ubatching import get_current_ubatch_context
|
||||
|
||||
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||
@ -952,12 +954,12 @@ class FusedMoE(torch.nn.Module):
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||
act_dtype = vllm_config.model_config.dtype
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
(MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||
(2, MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||
dtype=act_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
self.batched_router_logits = torch.zeros(
|
||||
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
||||
(2, MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
||||
dtype=act_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
@ -1376,15 +1378,19 @@ class FusedMoE(torch.nn.Module):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
|
||||
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
|
||||
|
||||
assert (self.batched_hidden_states.size(0) # type: ignore
|
||||
assert (batched_hidden_states.size(0) # type: ignore
|
||||
>= chunk_size)
|
||||
assert (self.batched_router_logits.size(0) # type: ignore
|
||||
assert (batched_router_logits.size(0) # type: ignore
|
||||
>= chunk_size)
|
||||
staged_hidden_states = self.batched_hidden_states[:
|
||||
chunk_size, :] # type: ignore
|
||||
staged_router_logits = self.batched_router_logits[:
|
||||
chunk_size, :] # type: ignore
|
||||
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
|
||||
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
|
||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||
|
||||
|
||||
@ -134,14 +134,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_recv=not send,
|
||||
)
|
||||
|
||||
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
dispatch(True) # Send
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||
dispatch(False) # Recv
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# torch.cuda.synchronize()
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||
@ -185,11 +185,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_recv=not send,
|
||||
)
|
||||
|
||||
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
combine(True)
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||
combine(False)
|
||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user