[Kernels] Overlap shared experts with send/recv (#23273)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm 2025-09-03 12:35:18 -04:00 committed by GitHub
parent fa4311d85f
commit e9b92dcd89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 885 additions and 227 deletions

View File

@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
### FusedMoEPrepareAndFinalize ### FusedMoEPrepareAndFinalize
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions. The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, the `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of the worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks") ![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks")
@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. `FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. `FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. `FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.

View File

@ -87,6 +87,11 @@ def parse_args():
default=0.8, default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
) )
parser.add_argument(
"--compilation-config",
type=int,
help=("Compilation optimization (O) level 0-3."),
)
parser.add_argument( parser.add_argument(
"--quantization", "--quantization",
type=str, type=str,
@ -106,6 +111,7 @@ def main(
trust_remote_code, trust_remote_code,
max_num_seqs, max_num_seqs,
max_model_len, max_model_len,
compilation_config,
gpu_memory_utilization, gpu_memory_utilization,
quantization, quantization,
): ):
@ -162,6 +168,7 @@ def main(
max_model_len=max_model_len, max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
quantization=quantization, quantization=quantization,
compilation_config=compilation_config,
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
@ -218,6 +225,7 @@ if __name__ == "__main__":
args.trust_remote_code, args.trust_remote_code,
args.max_num_seqs, args.max_num_seqs,
args.max_model_len, args.max_model_len,
args.compilation_config,
args.gpu_memory_utilization, args.gpu_memory_utilization,
args.quantization, args.quantization,
), ),

View File

@ -4,10 +4,11 @@
Run `pytest tests/kernels/test_pplx_moe.py`. Run `pytest tests/kernels/test_pplx_moe.py`.
""" """
import copy
import itertools import itertools
import textwrap import textwrap
import traceback import traceback
from typing import Callable, Optional from typing import Callable, Optional, Union
import pytest import pytest
import torch import torch
@ -21,7 +22,10 @@ try:
except ImportError: except ImportError:
has_pplx = False has_pplx = False
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
_set_vllm_config)
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
naive_batched_moe)
from tests.kernels.quant_utils import dequant from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
@ -511,7 +515,8 @@ def pplx_moe(
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
use_compile: bool = False, use_compile: bool = False,
use_cudagraphs: bool = True, use_cudagraphs: bool = True,
) -> torch.Tensor: shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = a.shape num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0] num_experts = w1.shape[0]
@ -546,6 +551,7 @@ def pplx_moe(
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
shared_experts,
) )
# Note: workers with the same dp_rank must use the exact same inputs. # Note: workers with the same dp_rank must use the exact same inputs.
@ -586,7 +592,11 @@ def pplx_moe(
global_num_experts=num_experts) global_num_experts=num_experts)
if use_cudagraphs: if use_cudagraphs:
out.fill_(0) if isinstance(out, tuple):
out[0].fill_(0)
out[1].fill_(0)
else:
out.fill_(0)
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream): with torch.cuda.graph(graph, stream=stream):
@ -626,6 +636,7 @@ def _pplx_moe(
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
use_internode: bool = False, use_internode: bool = False,
shared_experts: Optional[torch.nn.Module] = None,
): ):
try: try:
if use_internode: if use_internode:
@ -666,6 +677,11 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config): with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
if shared_experts is not None:
shared_output = shared_experts(a)
else:
shared_output = None
torch_output = torch_experts( torch_output = torch_experts(
a, a,
w1, w1,
@ -696,7 +712,7 @@ def _pplx_moe(
block_shape=block_shape, block_shape=block_shape,
) )
pplx_output = pplx_moe( pplx_outputs = pplx_moe(
group_name, group_name,
rank, rank,
world_size, world_size,
@ -713,8 +729,24 @@ def _pplx_moe(
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
shared_experts=shared_experts,
) )
if shared_experts is None:
pplx_shared_output = None
pplx_output = pplx_outputs
assert isinstance(pplx_output, torch.Tensor)
else:
pplx_shared_output, pplx_output = pplx_outputs
if shared_output is not None:
assert pplx_shared_output is not None
chunked_shared_output = chunk_by_rank(
shared_output, pgi.rank,
pgi.world_size).to(pplx_shared_output.device)
else:
chunked_shared_output = None
chunked_batch_output = chunk_by_rank( chunked_batch_output = chunk_by_rank(
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
@ -727,6 +759,15 @@ def _pplx_moe(
chunked_batch_output, chunked_batch_output,
atol=3e-2, atol=3e-2,
rtol=3e-2) rtol=3e-2)
if shared_experts is not None:
assert chunked_shared_output is not None
assert pplx_shared_output is not None
torch.testing.assert_close(pplx_shared_output,
chunked_shared_output,
atol=3e-2,
rtol=3e-2)
finally: finally:
if use_internode: if use_internode:
nvshmem_finalize() nvshmem_finalize()
@ -788,7 +829,8 @@ def test_pplx_moe_slow(
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
make_weights: bool, test_fn: Callable): use_shared_experts: bool, make_weights: bool,
test_fn: Callable):
def format_result(msg, ex=None): def format_result(msg, ex=None):
if ex is not None: if ex is not None:
@ -803,6 +845,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
else: else:
print(f"PASSED {msg}") print(f"PASSED {msg}")
if use_shared_experts:
# Note: this config is only needed for the non-naive shared experts.
new_vllm_config = copy.deepcopy(vllm_config)
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
new_vllm_config.parallel_config.enable_expert_parallel = True
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
pgi.local_rank)
current_platform.seed_everything(7) current_platform.seed_everything(7)
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
[False, True], [None, [128, 128]]) [False, True], [None, [128, 128]])
@ -819,9 +869,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
use_fp8_w8a8 = False use_fp8_w8a8 = False
quant_dtype = None quant_dtype = None
test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " test_desc = (
f"dtype={dtype}, per_act_token={per_act_token_quant}, " f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"block_shape={block_shape}") f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}, use_internode={use_internode}, "
f"use_shared_experts={use_shared_experts}")
if not use_fp8_w8a8 and (per_act_token_quant if not use_fp8_w8a8 and (per_act_token_quant
or block_shape is not None): or block_shape is not None):
@ -852,6 +904,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
args["w1_s"] = w1_s args["w1_s"] = w1_s
args["w2_s"] = w2_s args["w2_s"] = w2_s
if use_shared_experts:
args["shared_experts"] = make_shared_experts(
n,
k,
in_dtype=a.dtype,
quant_dtype=quant_dtype,
)
try: try:
test_fn( test_fn(
pgi=pgi, pgi=pgi,
@ -891,18 +951,20 @@ def test_pplx_prepare_finalize(
current_platform.seed_everything(7) current_platform.seed_everything(7)
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
use_internode, False, _pplx_prepare_finalize) use_internode, False, False, _pplx_prepare_finalize)
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx @requires_pplx
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
def test_pplx_moe( def test_pplx_moe(
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool, use_internode: bool,
use_shared_experts: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
_pplx_moe) use_shared_experts, True, _pplx_moe)

View File

@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX) FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@ -282,3 +283,151 @@ def per_token_cast_to_fp8(
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
# CustomOp?
class BaselineMM(torch.nn.Module):
def __init__(
self,
b: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.b = b.to(dtype=torch.float32)
self.out_dtype = out_dtype
def forward(
self,
a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
return torch.mm(a.to(dtype=torch.float32),
self.b).to(self.out_dtype), None
class TestMLP(torch.nn.Module):
def __init__(
self,
w1: torch.Tensor,
w2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.gate_up_proj = BaselineMM(w1, out_dtype)
self.down_proj = BaselineMM(w2, out_dtype)
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
def make_naive_shared_experts(
N: int,
K: int,
in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
return TestMLP(w1, w2, out_dtype=in_dtype)
class RealMLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
w1: torch.Tensor,
w2: torch.Tensor,
hidden_act: str = "silu",
quant_config=None,
reduce_results: bool = True,
prefix: str = "",
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
) -> None:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, RowParallelLinear)
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.gate_up_proj.register_parameter(
"weight", torch.nn.Parameter(w1, requires_grad=False))
self.gate_up_proj.register_parameter(
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
self.gate_up_proj.register_parameter(
"input_scale",
None) #torch.nn.Parameter(None, requires_grad=False))
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
self.down_proj.register_parameter(
"weight", torch.nn.Parameter(w2, requires_grad=False))
self.down_proj.register_parameter(
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
self.down_proj.register_parameter(
"input_scale",
None) #torch.nn.Parameter(None, requires_grad=False))
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
def make_shared_experts(
N: int,
K: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
) -> torch.nn.Module:
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
1,
N,
K,
in_dtype=in_dtype,
quant_dtype=quant_dtype,
)
old_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(in_dtype)
if quant_dtype == torch.float8_e4m3fn:
w1 = w1[0].transpose(0, 1)
w2 = w2[0].transpose(0, 1)
w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
quant_config = Fp8Config(True)
else:
w1 = w1[0]
w2 = w2[0]
w1_s = None
w2_s = None
quant_config = None
return RealMLP(K,
N,
w1,
w2,
"silu",
quant_config,
w1_s=w1_s,
w2_s=w2_s)
finally:
torch.set_default_dtype(old_dtype)

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any from typing import Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -13,11 +13,6 @@ from .base_device_communicator import All2AllManagerBase, Cache
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None
class NaiveAll2AllManager(All2AllManagerBase): class NaiveAll2AllManager(All2AllManagerBase):
""" """

View File

@ -252,7 +252,10 @@ class DeviceCommunicatorBase:
moe_modules = [ moe_modules = [
module for module in model.modules() module for module in model.modules()
if module.__class__.__name__ == "FusedMoE" # TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module) module.quant_method.init_prepare_finalize(module)

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Callable, Optional, Union
import deep_ep import deep_ep
import torch import torch
@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
self.dp_size = dp_size self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset self.rank_expert_offset = rank_expert_offset
self.async_prepare = True
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. We store the handle here so it is available to the
# combine function. # combine function.
@ -56,10 +58,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return None return None
return deep_ep.Buffer.get_combine_config(self.dp_size) return deep_ep.Buffer.get_combine_config(self.dp_size)
def _do_dispatch(self, tokens: torch.Tensor, def _do_dispatch(
token_scales: Optional[torch.Tensor], self,
rank_topk_ids: torch.Tensor, tokens: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int): token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1, expert_alignment=1,
config=self._get_dispatch_config(), config=self._get_dispatch_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=self.async_prepare,
allocate_on_comm_stream=False) allocate_on_comm_stream=False)
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
a1_scale,
quant_config,
)
def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
expert_topk_ids: Optional[torch.Tensor],
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: Optional[torch.Tensor],
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if self.async_prepare:
event.current_stream_wait()
if has_scales: if has_scales:
expert_x, expert_x_scale = token_data expert_x, expert_x_scale = token_data
else: else:
@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset # DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns # the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces. # with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where( expert_topk_ids = torch.where(
expert_topk_ids == -1, expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0, num_experts - 1 if self.rank_expert_offset == 0 else 0,
@ -123,10 +159,28 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device) expert_num_tokens_per_expert_list, device=expert_x.device)
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) expert_topk_weights)
def prepare( def supports_async(self) -> bool:
return True
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
@ -137,9 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> Callable:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
@ -159,37 +211,37 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
if a1q_scale is not None and a1q_scale.numel() == 1: if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1) a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, a1_post_scale = None
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
else: else:
# Dispatch and Quant a1q = a1
# DeepEP kernels only support dispatching block-quantized a1q_scale = None
# activation scales. a1_post_scale = a1_scale
# Dispatch in bfloat16
(expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, return self._do_dispatch(tokens=a1q,
expert_topk_weights) token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
def finalize( def finalize(
self, self,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Callable, Optional, Union
import deep_ep import deep_ep
import torch import torch
@ -75,7 +75,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self, self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype, a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None], quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool, per_act_token_quant: bool,
@ -110,7 +109,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales return x, x_scales
def prepare( def supports_async(self) -> bool:
return True
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
@ -121,9 +123,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.ReceiverType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
@ -155,16 +155,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch, use_fp8=self.use_fp8_dispatch,
async_finish=False, async_finish=False,
return_recv_hook=False) return_recv_hook=True)
return lambda: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)
def _receiver(
self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook()
expert_x, expert_x_scale = self._do_quant( expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
expert_tokens_meta = mk.ExpertTokensMetadata( expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return (expert_x, expert_x_scale, expert_tokens_meta, None, None) return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
def finalize( def finalize(
self, self,

View File

@ -56,9 +56,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args # TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.PrepareResultType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)

View File

@ -506,9 +506,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.PrepareResultType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)

View File

@ -4,7 +4,7 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from enum import Enum from enum import Enum
from typing import Callable, Literal, Optional, overload from typing import Callable, Literal, Optional, Union, overload
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -215,6 +215,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = FusedMoEModularKernel( self.fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
layer.shared_experts,
) )
def select_gemm_impl( def select_gemm_impl(
@ -252,7 +253,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
raise NotImplementedError raise NotImplementedError
@ -409,7 +410,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
assert logical_to_physical_map is not None assert logical_to_physical_map is not None
@ -461,7 +462,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
@ -547,7 +548,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \ if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \ logical_to_physical_map is not None or \
logical_replica_count is not None: logical_replica_count is not None:
@ -594,7 +595,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \ if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \ logical_to_physical_map is not None or \
logical_replica_count is not None: logical_replica_count is not None:
@ -633,7 +634,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
@ -948,6 +949,10 @@ class FusedMoE(CustomOp):
dtype=moe.in_dtype, dtype=moe.in_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
@ -1400,6 +1405,7 @@ class FusedMoE(CustomOp):
return [ return [
weight.view(self.local_num_experts, -1) for name, weight in weights weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS if name not in NON_EXPERT_WEIGHTS
and not name.startswith("_shared_experts.")
] ]
def set_eplb_state( def set_eplb_state(
@ -1582,25 +1588,45 @@ class FusedMoE(CustomOp):
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
og_hidden_states = hidden_states.shape[-1] og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states: if self.hidden_size != og_hidden_states:
hidden_states = F.pad(hidden_states, hidden_states = F.pad(hidden_states,
(0, self.hidden_size - og_hidden_states), (0, self.hidden_size - og_hidden_states),
mode='constant', mode='constant',
value=0.0) value=0.0)
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(
hidden_states, router_logits,
self.layer_name)[..., :og_hidden_states]
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, if self.shared_experts is None:
full_router_logits: torch.Tensor): if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name)
return fused_output[..., :og_hidden_states]
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, self.layer_name)
return (shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states])
def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.batched_hidden_states is not None assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype assert self.batched_hidden_states.dtype == full_hidden_states.dtype
@ -1611,7 +1637,10 @@ class FusedMoE(CustomOp):
assert ( assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1)) self.batched_router_logits.size(-1) == full_router_logits.size(-1))
full_final_hidden_states = torch.empty_like(full_hidden_states) full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False): def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start chunk_size = chunk_end - chunk_start
@ -1652,9 +1681,21 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
) )
assert self.shared_experts is None or isinstance(
final_hidden_states, tuple)
if not skip_result_store: if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_( if self.shared_experts is None:
final_hidden_states, non_blocking=True) full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states,
non_blocking=True)
else:
full_shared_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[0],
non_blocking=True)
full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[1],
non_blocking=True)
ctx = get_forward_context() ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
@ -1675,10 +1716,17 @@ class FusedMoE(CustomOp):
chunk_end, chunk_end,
skip_result_store=chunk_start_ >= num_tokens) skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states if self.shared_experts is None:
return full_fused_final_hidden_states
else:
return (full_shared_final_hidden_states,
full_fused_final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel # Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled. # only when data parallelism (DP) is enabled.
@ -1698,6 +1746,15 @@ class FusedMoE(CustomOp):
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if (not isinstance(self.quant_method.fused_experts,
FusedMoEModularKernel)
and self.shared_experts is not None):
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
@ -1722,14 +1779,30 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
) )
if do_naive_dispatch_combine: if shared_output is not None:
final_hidden_states = get_ep_group().combine(final_hidden_states) assert not isinstance(final_hidden_states, tuple)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): assert self.shared_experts is not None
# Default set to False. (May have to add shared expert outputs. final_hidden_states = (
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( shared_output,
final_hidden_states) final_hidden_states,
)
return final_hidden_states def reduce_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
states = self.maybe_all_reduce_tensor_model_parallel(states)
return states
if self.shared_experts is None:
return reduce_output(final_hidden_states)
else:
return (
reduce_output(final_hidden_states[0]),
reduce_output(final_hidden_states[1]),
)
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(
@ -1784,17 +1857,22 @@ class FusedMoE(CustomOp):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(
layer_name: str) -> torch.Tensor: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(
layer_name: str) -> torch.Tensor: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
@ -1807,6 +1885,37 @@ direct_register_custom_op(
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters # Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code # to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from math import prod from math import prod
from typing import Optional, final from typing import Callable, Optional, Union, final
import torch import torch
@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
raise NotImplementedError raise NotImplementedError
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType = tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]
ReceiverType = Callable[[], PrepareResultType]
# TODO: pass FusedMoEParallelConfig in as ctor parameter? # TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[ ) -> PrepareResultType:
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
""" """
Perform any quantization (and/or) dispatching needed Perform any quantization (and/or) dispatching needed for this kernel.
for this kernel.
- a1: The (unquantized) input to the MoE layer. - a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1 - a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make - a2_scale: Optional scales for the second MoE gemm. Required to make
@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def supports_async(self) -> bool:
"""
Indicates whether or not this class implements prepare_async.
"""
return False
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> ReceiverType:
"""
Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `prepare`, e.g.
receiver = obj.prepare_async(...)
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
is equivalent to:
a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def finalize( def finalize(
self, self,
@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts
assert prepare_finalize.activation_format == \ assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], ( fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}." f"{prepare_finalize.__class__.__name__}."
@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism. of weights, w1 and w2, and top-k gating mechanism.
@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, shared_output: torch.Tensor
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, if (not self.prepare_finalize.supports_async()
a1_scale, or self.shared_experts is None):
a2_scale,
topk_weights, # Run shared experts serially with dispatch.
topk_ids, if self.shared_experts is not None:
global_num_experts, shared_output = self.shared_experts(a1)
expert_map,
apply_router_weight_on_input, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
self.fused_experts.quant_config, _expert_topk_weights) = self.prepare_finalize.prepare(
) a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
else:
# Overlap shared expert compute with all2all dispatch.
receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
assert self.shared_experts is not None
shared_output = self.shared_experts(a1)
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks. # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
return output if self.shared_experts is None:
return output
else:
return shared_output, output

View File

@ -84,12 +84,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.max_num_tokens return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]: def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int32 return torch.uint32
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def prepare( def supports_async(self) -> bool:
return True
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
@ -100,9 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.ReceiverType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
_validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
quant_config.block_shape) quant_config.block_shape)
orig_a_scale_block_shape: Optional[int] = None
if a1q_scale is not None: if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1 scalar_scales = a1q_scale.numel() == 1
@ -205,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
out_expert_x_scale=expert_x_scale, out_expert_x_scale=expert_x_scale,
dp_x=a1q, dp_x=a1q,
dp_x_scale=a1q_scale, dp_x_scale=a1q_scale,
indices=topk_ids.view(dtype=torch.uint32), indices=topk_ids,
bound_m=bound_m, bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
a1q,
a1q_scale,
topk_ids,
bound_m,
orig_a_scale_block_shape,
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor],
a1q: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor,
bound_m: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int],
) -> mk.PrepareResultType:
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
) )
if expert_x_scale is not None: if expert_x_scale is not None:
@ -218,6 +258,31 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
return receiver()
def finalize( def finalize(
self, self,
output: torch.Tensor, output: torch.Tensor,

View File

@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.PrepareResultType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
@ -505,7 +505,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -474,7 +474,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None assert self.fused_experts is None

View File

@ -3,7 +3,7 @@
import enum import enum
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional, Union
import torch import torch
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
@ -358,7 +358,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:
@ -819,7 +819,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
@ -1069,7 +1069,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:
@ -1375,7 +1375,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:
@ -1608,7 +1608,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import torch import torch
@ -128,7 +128,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -988,7 +988,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
assert logical_to_physical_map is not None assert logical_to_physical_map is not None

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import gguf import gguf
import torch import torch
@ -540,7 +540,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -654,7 +654,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -491,7 +491,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.") "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
@ -1366,7 +1366,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import torch import torch
@ -305,7 +305,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional from typing import Callable, Optional, Union
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@ -554,7 +554,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import torch import torch
@ -226,7 +226,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:
@ -390,7 +390,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -3,7 +3,7 @@
# Copyright © 2025, Oracle and/or its affiliates. # Copyright © 2025, Oracle and/or its affiliates.
import os import os
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -291,7 +291,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
if enable_eplb: if enable_eplb:

View File

@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import (
SharedFusedMoE)
__all__ = ["SharedFusedMoE"]

View File

@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
    """
    A FusedMoE operation that also computes the results of shared experts.
    If an all2all communicator is being used the shared expert computation
    can be interleaved with the fused all2all dispatch communication step.
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        use_overlapped: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped

    @property
    def shared_experts(self) -> Optional[torch.nn.Module]:
        # Expose the shared experts to the base class only when overlapping
        # is enabled; otherwise they are invoked directly in forward().
        return self._shared_experts if self.use_overlapped else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.use_overlapped:
            # Overlapped path: the base class runs the shared experts
            # interleaved with the fused experts and returns both outputs.
            return super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )

        # Non-overlapped path: compute the shared experts eagerly, then
        # run the routed experts via the base class.
        shared_out = self._shared_experts(hidden_states)

        # Reduce outputs if necessary, since the MLP should
        # have been created with reduce_results=False.
        if (self.reduce_results and self.tp_size > 1
                and self.must_reduce_shared_expert_outputs()):
            shared_out = tensor_model_parallel_all_reduce(shared_out)

        fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_out, fused_out

View File

@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
@ -147,63 +148,85 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_end = (self.physical_expert_start + self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts) self.n_local_physical_experts)
self.experts = FusedMoE( if config.n_shared_experts is None:
num_experts=config.n_routed_experts, self.experts = FusedMoE(
top_k=config.num_experts_per_tok, num_experts=config.n_routed_experts,
hidden_size=config.hidden_size, top_k=config.num_experts_per_tok,
intermediate_size=config.moe_intermediate_size, hidden_size=config.hidden_size,
reduce_results=False, intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob, reduce_results=False,
quant_config=quant_config, renormalize=config.norm_topk_prob,
use_grouped_topk=True, quant_config=quant_config,
num_expert_group=config.n_group, use_grouped_topk=True,
topk_group=config.topk_group, num_expert_group=config.n_group,
prefix=f"{prefix}.experts", topk_group=config.topk_group,
scoring_func=config.scoring_func, prefix=f"{prefix}.experts",
# we do scaling outside, set factor to 1.0 to avoid double mul scoring_func=config.scoring_func,
routed_scaling_factor=1.0, # we do scaling outside, set factor to 1.0 to avoid double mul
e_score_correction_bias=self.gate.e_score_correction_bias, routed_scaling_factor=1.0,
enable_eplb=self.enable_eplb, e_score_correction_bias=self.gate.e_score_correction_bias,
num_redundant_experts=self.n_redundant_experts) enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
if config.n_shared_experts is not None: self.shared_experts = None
else:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs( reduce_results=False,
),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16: fused_moe_out = self.experts(hidden_states=hidden_states,
final_hidden_states = self.experts( router_logits=router_logits)
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
else: else:
# Fix FP16 overflow shared_output = None
# See DeepseekV2DecoderLayer for more details. final_hidden_states = fused_moe_out
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) # Fix FP16 overflow
if shared_output is not None: # See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states *= self.routed_scaling_factor
else: elif self.shared_experts is not None:
# Fix FP16 overflow assert shared_output is not None
# See DeepseekV2DecoderLayer for more details. shared_output *= (1. / self.routed_scaling_factor)
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = ( final_hidden_states = (

View File

@ -184,6 +184,8 @@ class Glm4MoE(nn.Module):
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
router_logits = self.gate(hidden_states.to(dtype=torch.float32)) router_logits = self.gate(hidden_states.to(dtype=torch.float32))
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,

View File

@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
quant_config=None, quant_config=None,
prefix=f"{prefix}.router") prefix=f"{prefix}.router")
self.experts = FusedMoE( self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False,
)
self.experts = SharedFusedMoE(
shared_experts=self.shared_expert,
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
reduce_results=False, reduce_results=False,
renormalize=False, renormalize=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts") prefix=f"{prefix}.experts",
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
) )
def forward(self, hidden_states): def forward(self, hidden_states):
router_logits, _ = self.router(hidden_states) router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts( shared_out, routed_out = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )

View File

@ -500,7 +500,8 @@ class Worker(WorkerBase):
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
moe_modules = [ moe_modules = [
module for module in self.model_runner.model.modules() module for module in self.model_runner.model.modules()
if module.__class__.__name__ == "FusedMoE" if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
] ]
num_local_experts = moe_modules[0].moe_config.num_local_experts num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(module.moe_config.num_local_experts == num_local_experts assert all(module.moe_config.num_local_experts == num_local_experts