mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 01:37:53 +08:00
fix using the same buffer across ubatches
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
642bf2dd8b
commit
ef3c01c975
@ -31,6 +31,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
from vllm.v1.worker.ubatching import get_current_ubatch_context
|
||||||
|
|
||||||
|
|
||||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||||
@ -952,12 +954,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||||
act_dtype = vllm_config.model_config.dtype
|
act_dtype = vllm_config.model_config.dtype
|
||||||
self.batched_hidden_states = torch.zeros(
|
self.batched_hidden_states = torch.zeros(
|
||||||
(MOE_DP_CHUNK_SIZE, self.hidden_size),
|
(2, MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||||
dtype=act_dtype,
|
dtype=act_dtype,
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
self.batched_router_logits = torch.zeros(
|
self.batched_router_logits = torch.zeros(
|
||||||
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
(2, MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
||||||
dtype=act_dtype,
|
dtype=act_dtype,
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
@ -1377,14 +1379,18 @@ class FusedMoE(torch.nn.Module):
|
|||||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||||
|
|
||||||
assert (self.batched_hidden_states.size(0) # type: ignore
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
|
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||||
|
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
|
||||||
|
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
|
||||||
|
|
||||||
|
assert (batched_hidden_states.size(0) # type: ignore
|
||||||
>= chunk_size)
|
>= chunk_size)
|
||||||
assert (self.batched_router_logits.size(0) # type: ignore
|
assert (batched_router_logits.size(0) # type: ignore
|
||||||
>= chunk_size)
|
>= chunk_size)
|
||||||
staged_hidden_states = self.batched_hidden_states[:
|
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
|
||||||
chunk_size, :] # type: ignore
|
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
|
||||||
staged_router_logits = self.batched_router_logits[:
|
|
||||||
chunk_size, :] # type: ignore
|
|
||||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||||
|
|
||||||
|
|||||||
@ -134,14 +134,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
do_recv=not send,
|
do_recv=not send,
|
||||||
)
|
)
|
||||||
|
|
||||||
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
dispatch(True) # Send
|
dispatch(True) # Send
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||||
dispatch(False) # Recv
|
dispatch(False) # Recv
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||||
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
if expert_x_scale is not None:
|
if expert_x_scale is not None:
|
||||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||||
@ -185,11 +185,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
do_recv=not send,
|
do_recv=not send,
|
||||||
)
|
)
|
||||||
|
|
||||||
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
combine(True)
|
combine(True)
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||||
combine(False)
|
combine(False)
|
||||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||||
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user