Set up DeepEP LL for ubatching

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-24 21:20:49 +00:00
parent ff2dd13145
commit a4def24c2c
4 changed files with 32 additions and 11 deletions

View File

@ -147,7 +147,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
has_deepep = importlib.util.find_spec("deep_ep") is not None
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
self.handle_cache = Cache()
self.handle_caches = [Cache(), Cache()]
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
@ -174,6 +174,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
self.handle_cache = self.handle_caches[0]
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@ -265,7 +266,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
handle: deep_ep.Buffer = self.handle_caches[0].get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
@ -273,3 +274,10 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle
def get_handles(self, kwargs):
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer)
second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer)
return [first_handle, second_handle]

View File

@ -7,6 +7,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.v1.worker.ubatching import (
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
yield_and_switch_from_compute_to_comm_impl)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
@ -38,7 +41,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
def __init__(self,
buffer: deep_ep.Buffer,
buffers: list[deep_ep.Buffer],
world_size: int,
dp_size: int,
max_tokens_per_rank: int,
@ -47,7 +50,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
use_fp8_dispatch: bool = False):
super().__init__()
self.buffer = buffer
self.buffers = buffers
self.world_size = world_size
self.dp_size = dp_size
self.quant_dtype = quant_dtype
@ -127,9 +130,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
Optional[torch.Tensor], Optional[torch.Tensor]]:
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
f"{self.SUPPORTED_HIDDEN_SIZES}")
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
# assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
# (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
# f"{self.SUPPORTED_HIDDEN_SIZES}")
if self.use_fp8_dispatch:
assert hidden_size % 128 == 0, \
@ -150,7 +156,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
self.buffer.low_latency_dispatch(a1,
self.buffers[a2a_idx].low_latency_dispatch(a1,
rank_topk_ids,
self.max_tokens_per_rank,
num_experts,
@ -168,6 +174,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool) -> None:
assert self.handle is not None
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
combine_topk_weights = topk_weights
if apply_router_weight_on_input:
@ -175,7 +184,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine(
_, event, hook = self.buffers[a2a_idx].low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,

View File

@ -377,12 +377,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
handles = all2all_manager.get_handles(all_to_all_args)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
handles,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
max_tokens_per_rank=moe.max_num_tokens,

View File

@ -1467,6 +1467,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
slice(0, num_tokens), None, False)
return input_ids, positions, inputs_embeds, intermediate_tensors
def _get_model_inputs(self, tokens_slice: slice,