[Bug] Fix DBO IMA issue for DeepEPHT (#27666)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-10-29 16:28:27 -04:00 committed by GitHub
parent d4aa144343
commit b5d90f7400
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 3 deletions

View File

@ -16,6 +16,7 @@ from vllm.utils.math_utils import round_up
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
dbo_get_previous_event,
dbo_switch_to_comm,
dbo_switch_to_compute,
dbo_switch_to_compute_sync,
@ -110,6 +111,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()
# capture a DeepEP event and pass it as previous_event so
# DeepEP honors the dependency internally.
previous_event = dbo_get_previous_event(self.buffer.capture)
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
@ -119,7 +124,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
)
@ -148,7 +153,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# to this value.
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
previous_event=previous_event,
async_finish=self.async_prepare and not dbo_enabled(),
allocate_on_comm_stream=False,
)
@ -339,13 +344,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert fused_expert_output.dtype == torch.bfloat16, (
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
)
previous_event = dbo_get_previous_event(self.buffer.capture)
combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output,
handle=handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
previous_event=previous_event,
async_finish=do_async and not dbo_enabled(),
allocate_on_comm_stream=False,
)

View File

@ -185,6 +185,15 @@ def dbo_register_recv_hook(recv_hook):
next_ctx.recv_hook = recv_hook
def dbo_get_previous_event(func, *args, **kwargs):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
ctx = _CURRENT_CONTEXTS[ctx_idx]
# execute callable on the ubatch compute stream to record/wait events there
with torch.cuda.stream(ctx.compute_stream):
return func(*args, **kwargs)
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,