[Bug] Fix DBO IMA issue for DeepEPHT (#27666)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

commit b5d90f7400, parent d4aa144343
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -16,6 +16,7 @@ from vllm.utils.math_utils import round_up
 from vllm.v1.worker.ubatching import (
     dbo_current_ubatch_id,
     dbo_enabled,
+    dbo_get_previous_event,
     dbo_switch_to_comm,
     dbo_switch_to_compute,
     dbo_switch_to_compute_sync,
@@ -110,6 +111,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         # for the other ubatch before the dispatch kernel starts.
         dbo_yield_and_switch_from_compute_to_comm()
 
+        # capture a DeepEP event and pass it as previous_event so
+        # DeepEP honors the dependency internally.
+        previous_event = dbo_get_previous_event(self.buffer.capture)
+
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -119,7 +124,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         ) = self.buffer.get_dispatch_layout(
             topk_idx=rank_topk_ids,
             num_experts=num_experts,
-            previous_event=None,
+            previous_event=previous_event,
             async_finish=False,
             allocate_on_comm_stream=False,
         )
@@ -148,7 +153,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             # to this value.
             expert_alignment=1,
             config=self._get_dispatch_config(),
-            previous_event=None,
+            previous_event=previous_event,
             async_finish=self.async_prepare and not dbo_enabled(),
             allocate_on_comm_stream=False,
         )
@@ -339,13 +344,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         assert fused_expert_output.dtype == torch.bfloat16, (
             f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
         )
+        previous_event = dbo_get_previous_event(self.buffer.capture)
         combined_x, _, event = self.buffer.combine(
             # HT combine only supports BF16
             x=fused_expert_output,
             handle=handle,
             topk_weights=None,
             config=self._get_combine_config(),
-            previous_event=None,
+            previous_event=previous_event,
             async_finish=do_async and not dbo_enabled(),
             allocate_on_comm_stream=False,
         )
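With DBO enabled, `async_finish` is forced off (note the `and not dbo_enabled()` guards), so these DeepEP dispatch/combine calls previously ran with `previous_event=None` and could race work still queued on the ubatch compute stream, the likely source of the IMA (illegal memory access) named in the commit title. The fix captures an event via `self.buffer.capture` on the compute stream and hands it to DeepEP as `previous_event`; per the new comment, DeepEP then honors the dependency internally. Below is a minimal sketch of the underlying CUDA-event ordering pattern, assuming a CUDA device; `compute`, `comm`, `ev`, and the tensors are illustrative stand-ins, not DeepEP internals:

```python
import torch

# Sketch of the stream-ordering pattern the patch relies on. `compute`
# and `comm` stand in for the ubatch compute stream and DeepEP's
# communication stream; `ev` plays the role of previous_event.
compute = torch.cuda.Stream()
comm = torch.cuda.Stream()

x = torch.randn(4096, device="cuda")

with torch.cuda.stream(compute):
    x.mul_(2)           # producer work the dispatch depends on
    ev = torch.cuda.Event()
    ev.record()         # "capture": marks compute progress up to here

with torch.cuda.stream(comm):
    comm.wait_event(ev) # honor the dependency before reading x
    y = x + 1           # stand-in for the dispatch/combine kernel

torch.cuda.synchronize()
```

Without the `wait_event`, the consumer kernel on `comm` may read `x` before the producer on `compute` finishes, which is exactly the hazard `previous_event=None` left open once async completion is disabled under DBO.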
--- a/vllm/v1/worker/ubatching.py
+++ b/vllm/v1/worker/ubatching.py
@@ -185,6 +185,15 @@ def dbo_register_recv_hook(recv_hook):
         next_ctx.recv_hook = recv_hook
 
 
+def dbo_get_previous_event(func, *args, **kwargs):
+    if len(_THREAD_ID_TO_CONTEXT) > 0:
+        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
+        ctx = _CURRENT_CONTEXTS[ctx_idx]
+        # execute callable on the ubatch compute stream to record/wait events there
+        with torch.cuda.stream(ctx.compute_stream):
+            return func(*args, **kwargs)
+
+
 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
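The new helper runs its callable on the current ubatch's compute stream, so the event recorded by `self.buffer.capture` marks compute progress rather than whatever stream happens to be current. When no ubatch thread is registered it falls through and implicitly returns `None`, preserving the pre-patch `previous_event=None` behavior outside DBO. A standalone illustration under assumed names (`Ctx` and the `capture` stand-in are hypothetical; the real registry and context objects live in vllm/v1/worker/ubatching.py):

```python
import threading
import torch

# Simplified stand-ins for the module-level registry in ubatching.py.
_THREAD_ID_TO_CONTEXT: dict[int, int] = {}
_CURRENT_CONTEXTS: list["Ctx"] = []

class Ctx:
    def __init__(self) -> None:
        self.compute_stream = torch.cuda.Stream()

def dbo_get_previous_event(func, *args, **kwargs):
    # Same logic as the diff: run `func` on this ubatch's compute stream
    # so any event it records is ordered against compute work.
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        ctx = _CURRENT_CONTEXTS[ctx_idx]
        with torch.cuda.stream(ctx.compute_stream):
            return func(*args, **kwargs)
    # falls through: returns None outside a DBO ubatch context

def capture() -> torch.cuda.Event:
    # hypothetical stand-in for deep_ep.Buffer.capture: record an event
    # on the current stream (the compute stream, thanks to the wrapper)
    ev = torch.cuda.Event()
    ev.record()
    return ev

assert dbo_get_previous_event(capture) is None  # no ubatch registered

_CURRENT_CONTEXTS.append(Ctx())
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = 0
print(dbo_get_previous_event(capture))  # event recorded on the compute stream
```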