mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:45:29 +08:00
[Bug] Fix DBO IMA issue for DeepEPHT (#27666)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
d4aa144343
commit
b5d90f7400
@ -16,6 +16,7 @@ from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id,
|
||||
dbo_enabled,
|
||||
dbo_get_previous_event,
|
||||
dbo_switch_to_comm,
|
||||
dbo_switch_to_compute,
|
||||
dbo_switch_to_compute_sync,
|
||||
@ -110,6 +111,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# for the other ubatch before the dispatch kernel starts.
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
|
||||
# capture a DeepEP event and pass it as previous_event so
|
||||
# DeepEP honors the dependency internally.
|
||||
previous_event = dbo_get_previous_event(self.buffer.capture)
|
||||
|
||||
(
|
||||
num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank,
|
||||
@ -119,7 +124,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
) = self.buffer.get_dispatch_layout(
|
||||
topk_idx=rank_topk_ids,
|
||||
num_experts=num_experts,
|
||||
previous_event=None,
|
||||
previous_event=previous_event,
|
||||
async_finish=False,
|
||||
allocate_on_comm_stream=False,
|
||||
)
|
||||
@ -148,7 +153,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# to this value.
|
||||
expert_alignment=1,
|
||||
config=self._get_dispatch_config(),
|
||||
previous_event=None,
|
||||
previous_event=previous_event,
|
||||
async_finish=self.async_prepare and not dbo_enabled(),
|
||||
allocate_on_comm_stream=False,
|
||||
)
|
||||
@ -339,13 +344,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
assert fused_expert_output.dtype == torch.bfloat16, (
|
||||
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
|
||||
)
|
||||
previous_event = dbo_get_previous_event(self.buffer.capture)
|
||||
combined_x, _, event = self.buffer.combine(
|
||||
# HT combine only supports BF16
|
||||
x=fused_expert_output,
|
||||
handle=handle,
|
||||
topk_weights=None,
|
||||
config=self._get_combine_config(),
|
||||
previous_event=None,
|
||||
previous_event=previous_event,
|
||||
async_finish=do_async and not dbo_enabled(),
|
||||
allocate_on_comm_stream=False,
|
||||
)
|
||||
|
||||
@ -185,6 +185,15 @@ def dbo_register_recv_hook(recv_hook):
|
||||
next_ctx.recv_hook = recv_hook
|
||||
|
||||
|
||||
def dbo_get_previous_event(func, *args, **kwargs):
    """Invoke ``func`` on the current thread's ubatch compute stream.

    The calling thread's ubatch context is looked up by thread ident, and
    ``func(*args, **kwargs)`` is executed with that context's
    ``compute_stream`` set as the current CUDA stream, so any event the
    callable records or waits on is issued on that stream.

    Args:
        func: Callable to execute (callers in this change pass a DeepEP
            buffer's ``capture`` method).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns, or ``None`` when no ubatch thread
        contexts are registered.
    """
    # No registered ubatch threads: there is no stream to switch to, so the
    # caller gets None back.
    if not _THREAD_ID_TO_CONTEXT:
        return None

    ubatch_ctx = _CURRENT_CONTEXTS[_THREAD_ID_TO_CONTEXT[threading.get_ident()]]
    # Run the callable with the ubatch compute stream as the current stream.
    with torch.cuda.stream(ubatch_ctx.compute_stream):
        result = func(*args, **kwargs)
    return result
|
||||
|
||||
|
||||
def make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user