mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 22:47:11 +08:00
Set up DeepEP LL (low-latency) for ubatching
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
ff2dd13145
commit
a4def24c2c
@ -147,7 +147,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
self.handle_caches = [Cache(), Cache()]
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
# reasonable defaults based on profiling.
|
||||
@ -174,6 +174,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = self.handle_caches[0]
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
@ -265,7 +266,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
handle: deep_ep.Buffer = self.handle_caches[0].get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
@ -273,3 +274,10 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
def get_handles(self, kwargs):
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer)
|
||||
second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer)
|
||||
return [first_handle, second_handle]
|
||||
|
||||
@ -7,6 +7,9 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.v1.worker.ubatching import (
|
||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
|
||||
yield_and_switch_from_compute_to_comm_impl)
|
||||
|
||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||
@ -38,7 +41,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
|
||||
|
||||
def __init__(self,
|
||||
buffer: deep_ep.Buffer,
|
||||
buffers: list[deep_ep.Buffer],
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
max_tokens_per_rank: int,
|
||||
@ -47,7 +50,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
use_fp8_dispatch: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.buffer = buffer
|
||||
self.buffers = buffers
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.quant_dtype = quant_dtype
|
||||
@ -127,9 +130,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
hidden_size = a1.size(1)
|
||||
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
|
||||
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
|
||||
f"{self.SUPPORTED_HIDDEN_SIZES}")
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||
# assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
|
||||
# (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
|
||||
# f"{self.SUPPORTED_HIDDEN_SIZES}")
|
||||
|
||||
if self.use_fp8_dispatch:
|
||||
assert hidden_size % 128 == 0, \
|
||||
@ -150,7 +156,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# Dispatch
|
||||
expert_x, expert_num_tokens, self.handle, event, hook = \
|
||||
self.buffer.low_latency_dispatch(a1,
|
||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||
rank_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
@ -168,6 +174,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool) -> None:
|
||||
|
||||
assert self.handle is not None
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||
|
||||
combine_topk_weights = topk_weights
|
||||
if apply_router_weight_on_input:
|
||||
@ -175,7 +184,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
combine_topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
_, event, hook = self.buffer.low_latency_combine(
|
||||
_, event, hook = self.buffers[a2a_idx].low_latency_combine(
|
||||
fused_expert_output,
|
||||
topk_ids,
|
||||
combine_topk_weights,
|
||||
|
||||
@ -377,12 +377,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
num_global_experts=moe.num_experts,
|
||||
num_local_experts=moe.num_experts //
|
||||
all2all_manager.world_size)
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
handles = all2all_manager.get_handles(all_to_all_args)
|
||||
|
||||
# Note (varun): Whether to use FP8 dispatch or not needs some
|
||||
# profiling. Turning it off for now.
|
||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||
handle,
|
||||
handles,
|
||||
world_size=all2all_manager.world_size,
|
||||
dp_size=all2all_manager.dp_world_size,
|
||||
max_tokens_per_rank=moe.max_num_tokens,
|
||||
|
||||
@ -1467,6 +1467,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device))
|
||||
|
||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||
slice(0, num_tokens), None, False)
|
||||
|
||||
|
||||
return input_ids, positions, inputs_embeds, intermediate_tensors
|
||||
|
||||
def _get_model_inputs(self, tokens_slice: slice,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user