mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 15:24:28 +08:00
[Feature] Shared Experts Overlap with FI deepgemm swap kernel, 2.2% throughput improvement and 3.6% TTFT improvement (#28879)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
87cbbdff63
commit
df44df0143
@ -50,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
prepare_finalize,
|
||||
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
||||
shared_experts,
|
||||
getattr(moe_layer, "shared_experts_stream", None),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -850,6 +850,45 @@ class FusedMoE(CustomOp):
|
||||
dp_size=get_dp_group().world_size,
|
||||
)
|
||||
|
||||
def _maybe_setup_shared_experts_stream(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
has_separate_shared_experts: bool,
|
||||
use_chunked_impl: bool,
|
||||
) -> tuple[bool, torch.Tensor | None]:
|
||||
use_shared_experts_stream = (
|
||||
has_separate_shared_experts
|
||||
and not use_chunked_impl
|
||||
and self.shared_experts_stream is not None
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states_clone: torch.Tensor | None = None
|
||||
if use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = hidden_states.clone()
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We dont need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
# Mark sync start point for the separate shared experts
|
||||
# stream here since we want to run in parallel with the
|
||||
# router/gate (next op below)
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
return use_shared_experts_stream, hidden_states_clone
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
shard_id: str,
|
||||
@ -1819,36 +1858,12 @@ class FusedMoE(CustomOp):
|
||||
|
||||
use_chunked_impl = self.use_dp_chunking
|
||||
|
||||
use_shared_experts_stream = (
|
||||
has_separate_shared_experts
|
||||
and not use_chunked_impl
|
||||
and self.shared_experts_stream is not None
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
use_shared_experts_stream, hidden_states_clone = (
|
||||
self._maybe_setup_shared_experts_stream(
|
||||
hidden_states, has_separate_shared_experts, use_chunked_impl
|
||||
)
|
||||
)
|
||||
|
||||
if use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = hidden_states.clone()
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We dont need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
# Mark sync start point for the separate shared experts
|
||||
# stream here since we want to run in parallel with the
|
||||
# router/gate (next op below)
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
# If router/gate provided, then apply it here.
|
||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||
# parallel execution of shared experts with the FusedMoE via
|
||||
|
||||
@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id,
|
||||
@ -709,11 +710,13 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
shared_experts_stream: torch.cuda.Stream | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.shared_experts_stream = shared_experts_stream
|
||||
|
||||
self._post_init_setup()
|
||||
assert (
|
||||
@ -890,6 +893,34 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
)
|
||||
|
||||
def _maybe_setup_shared_experts_stream(
|
||||
self, hidden_states: torch.Tensor
|
||||
) -> tuple[bool, torch.Tensor | None]:
|
||||
# decide whether to run shared experts on a separate CUDA stream to
|
||||
# overlap with the main fused MoE kernel.
|
||||
use_shared_experts_stream = (
|
||||
self.shared_experts is not None
|
||||
and self.shared_experts_stream is not None
|
||||
and hidden_states.is_cuda
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states_clone: torch.Tensor | None = None
|
||||
if use_shared_experts_stream and self.shared_experts_stream is not None:
|
||||
# TODO: Optimize this (complicated)
|
||||
# Note: this clone adds overhead but is required
|
||||
# for correctness with multiple CUDA streams and CUDA graph capture.
|
||||
hidden_states_clone = hidden_states.clone()
|
||||
# record that the clone will be used by the separate stream so its
|
||||
# lifetime is correctly tracked.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
return use_shared_experts_stream, hidden_states_clone
|
||||
|
||||
def _prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1077,12 +1108,30 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
hidden_states_clone: torch.Tensor | None = None,
|
||||
use_shared_experts_stream: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The _finalize method is a wrapper around self.prepare_finalize.finalize
|
||||
that handles DBO, async and shared expert overlap.
|
||||
"""
|
||||
shared_output: torch.Tensor | None = None
|
||||
|
||||
def maybe_run_shared_experts() -> torch.Tensor | None:
|
||||
if self.shared_experts is None:
|
||||
return None
|
||||
|
||||
if (
|
||||
not use_shared_experts_stream
|
||||
or self.shared_experts_stream is not None
|
||||
and (not hidden_states.is_cuda or not torch.cuda.is_available())
|
||||
):
|
||||
# fall back to running on the current stream
|
||||
return self.shared_experts(hidden_states)
|
||||
|
||||
assert hidden_states_clone is not None
|
||||
# launch shared experts on the dedicated stream.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
return self.shared_experts(hidden_states_clone)
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
@ -1095,8 +1144,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
shared_output = maybe_run_shared_experts()
|
||||
else:
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
@ -1107,8 +1155,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
shared_output = maybe_run_shared_experts()
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
@ -1131,12 +1178,28 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
receiver()
|
||||
|
||||
self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
def _wait_for_shared_experts_stream(
|
||||
self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
|
||||
) -> None:
|
||||
# ensure that any work enqueued on the shared_experts_stream is
|
||||
# completed before the shared_output tensor is consumed
|
||||
if (
|
||||
self.shared_experts is not None
|
||||
and use_shared_experts_stream
|
||||
and self.shared_experts_stream is not None
|
||||
and hidden_states.is_cuda
|
||||
and current_platform.is_cuda()
|
||||
):
|
||||
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1183,6 +1246,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
use_shared_experts_stream, hidden_states_clone = (
|
||||
self._maybe_setup_shared_experts_stream(hidden_states)
|
||||
)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
@ -1219,4 +1286,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
hidden_states_clone=hidden_states_clone,
|
||||
use_shared_experts_stream=use_shared_experts_stream,
|
||||
)
|
||||
|
||||
@ -45,7 +45,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user