mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 19:27:07 +08:00
yields now work with deepep_ll
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
a4def24c2c
commit
930efd02ab
@ -60,7 +60,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handle = None
|
||||
self.handles: list[Optional[tuple]] = [None, None]
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return self.max_tokens_per_rank
|
||||
@ -155,7 +155,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
|
||||
# Dispatch
|
||||
expert_x, expert_num_tokens, self.handle, event, hook = \
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
expert_x, expert_num_tokens, handle, event, hook = \
|
||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||
rank_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
@ -163,6 +164,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
self.handles[a2a_idx] = handle
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
||||
a1.dtype)
|
||||
@ -173,10 +176,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> None:
|
||||
|
||||
assert self.handle is not None
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||
handle = self.handles[a2a_idx]
|
||||
assert handle is not None
|
||||
|
||||
combine_topk_weights = topk_weights
|
||||
if apply_router_weight_on_input:
|
||||
@ -184,12 +188,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
combine_topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
_, event, hook = self.buffers[a2a_idx].low_latency_combine(
|
||||
fused_expert_output,
|
||||
topk_ids,
|
||||
combine_topk_weights,
|
||||
self.handle,
|
||||
handle,
|
||||
async_finish=False,
|
||||
zero_copy=False,
|
||||
return_recv_hook=False,
|
||||
out=output)
|
||||
# event.current_stream_wait()
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user