mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 00:17:09 +08:00
setup deepepll for ubatching
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
ff2dd13145
commit
a4def24c2c
@ -147,7 +147,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
|||||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||||
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
||||||
super().__init__(cpu_group)
|
super().__init__(cpu_group)
|
||||||
self.handle_cache = Cache()
|
self.handle_caches = [Cache(), Cache()]
|
||||||
|
|
||||||
# This is the DeepEP default. Stick to it till we can establish
|
# This is the DeepEP default. Stick to it till we can establish
|
||||||
# reasonable defaults based on profiling.
|
# reasonable defaults based on profiling.
|
||||||
@ -174,6 +174,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
|||||||
|
|
||||||
def __init__(self, cpu_group):
|
def __init__(self, cpu_group):
|
||||||
super().__init__(cpu_group)
|
super().__init__(cpu_group)
|
||||||
|
self.handle_cache = self.handle_caches[0]
|
||||||
|
|
||||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||||
@ -265,7 +266,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
|||||||
import deep_ep
|
import deep_ep
|
||||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
handle: deep_ep.Buffer = self.handle_caches[0].get_or_create(
|
||||||
buffer_kwargs, deep_ep.Buffer)
|
buffer_kwargs, deep_ep.Buffer)
|
||||||
# It is dangerous to set num sms outside this function. num_sms is not
|
# It is dangerous to set num sms outside this function. num_sms is not
|
||||||
# a part of the hash-key that identifies this object. If we are in a
|
# a part of the hash-key that identifies this object. If we are in a
|
||||||
@ -273,3 +274,10 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
|||||||
# in get_or_create must be updated.
|
# in get_or_create must be updated.
|
||||||
handle.set_num_sms(self.num_sms)
|
handle.set_num_sms(self.num_sms)
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
def get_handles(self, kwargs):
|
||||||
|
import deep_ep
|
||||||
|
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||||
|
first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer)
|
||||||
|
second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer)
|
||||||
|
return [first_handle, second_handle]
|
||||||
|
|||||||
@ -7,6 +7,9 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
moe_kernel_quantize_input)
|
||||||
|
from vllm.v1.worker.ubatching import (
|
||||||
|
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
|
||||||
|
yield_and_switch_from_compute_to_comm_impl)
|
||||||
|
|
||||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||||
@ -38,7 +41,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
|
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
buffer: deep_ep.Buffer,
|
buffers: list[deep_ep.Buffer],
|
||||||
world_size: int,
|
world_size: int,
|
||||||
dp_size: int,
|
dp_size: int,
|
||||||
max_tokens_per_rank: int,
|
max_tokens_per_rank: int,
|
||||||
@ -47,7 +50,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
use_fp8_dispatch: bool = False):
|
use_fp8_dispatch: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.buffer = buffer
|
self.buffers = buffers
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.quant_dtype = quant_dtype
|
self.quant_dtype = quant_dtype
|
||||||
@ -127,9 +130,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
|
||||||
hidden_size = a1.size(1)
|
hidden_size = a1.size(1)
|
||||||
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
f"{self.SUPPORTED_HIDDEN_SIZES}")
|
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||||
|
# assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
|
||||||
|
# (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
|
||||||
|
# f"{self.SUPPORTED_HIDDEN_SIZES}")
|
||||||
|
|
||||||
if self.use_fp8_dispatch:
|
if self.use_fp8_dispatch:
|
||||||
assert hidden_size % 128 == 0, \
|
assert hidden_size % 128 == 0, \
|
||||||
@ -150,7 +156,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
expert_x, expert_num_tokens, self.handle, event, hook = \
|
expert_x, expert_num_tokens, self.handle, event, hook = \
|
||||||
self.buffer.low_latency_dispatch(a1,
|
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||||
rank_topk_ids,
|
rank_topk_ids,
|
||||||
self.max_tokens_per_rank,
|
self.max_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
@ -168,6 +174,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
apply_router_weight_on_input: bool) -> None:
|
apply_router_weight_on_input: bool) -> None:
|
||||||
|
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
|
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||||
|
|
||||||
combine_topk_weights = topk_weights
|
combine_topk_weights = topk_weights
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
@ -175,7 +184,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
combine_topk_weights = torch.ones_like(topk_weights)
|
combine_topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
# TODO (varun) : Enable zero copy mode
|
# TODO (varun) : Enable zero copy mode
|
||||||
_, event, hook = self.buffer.low_latency_combine(
|
_, event, hook = self.buffers[a2a_idx].low_latency_combine(
|
||||||
fused_expert_output,
|
fused_expert_output,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
combine_topk_weights,
|
combine_topk_weights,
|
||||||
|
|||||||
@ -377,12 +377,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
num_global_experts=moe.num_experts,
|
num_global_experts=moe.num_experts,
|
||||||
num_local_experts=moe.num_experts //
|
num_local_experts=moe.num_experts //
|
||||||
all2all_manager.world_size)
|
all2all_manager.world_size)
|
||||||
handle = all2all_manager.get_handle(all_to_all_args)
|
handles = all2all_manager.get_handles(all_to_all_args)
|
||||||
|
|
||||||
# Note (varun): Whether to use FP8 dispatch or not needs some
|
# Note (varun): Whether to use FP8 dispatch or not needs some
|
||||||
# profiling. Turning it off for now.
|
# profiling. Turning it off for now.
|
||||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||||
handle,
|
handles,
|
||||||
world_size=all2all_manager.world_size,
|
world_size=all2all_manager.world_size,
|
||||||
dp_size=all2all_manager.dp_world_size,
|
dp_size=all2all_manager.dp_world_size,
|
||||||
max_tokens_per_rank=moe.max_num_tokens,
|
max_tokens_per_rank=moe.max_num_tokens,
|
||||||
|
|||||||
@ -1467,6 +1467,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=self.device))
|
device=self.device))
|
||||||
|
|
||||||
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||||
|
slice(0, num_tokens), None, False)
|
||||||
|
|
||||||
|
|
||||||
return input_ids, positions, inputs_embeds, intermediate_tensors
|
return input_ids, positions, inputs_embeds, intermediate_tensors
|
||||||
|
|
||||||
def _get_model_inputs(self, tokens_slice: slice,
|
def _get_model_inputs(self, tokens_slice: slice,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user