mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-19 15:17:16 +08:00
[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
fa4311d85f
commit
e9b92dcd89
@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
|
||||
|
||||
### FusedMoEPrepareAndFinalize
|
||||
|
||||
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions.
|
||||
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
||||
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
|
||||
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
||||
|
||||

|
||||
|
||||
@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the
|
||||
|
||||
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
|
||||
|
||||
@ -87,6 +87,11 @@ def parse_args():
|
||||
default=0.8,
|
||||
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compilation-config",
|
||||
type=int,
|
||||
help=("Compilation optimization (O) level 0-3."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
type=str,
|
||||
@ -106,6 +111,7 @@ def main(
|
||||
trust_remote_code,
|
||||
max_num_seqs,
|
||||
max_model_len,
|
||||
compilation_config,
|
||||
gpu_memory_utilization,
|
||||
quantization,
|
||||
):
|
||||
@ -162,6 +168,7 @@ def main(
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
quantization=quantization,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
@ -218,6 +225,7 @@ if __name__ == "__main__":
|
||||
args.trust_remote_code,
|
||||
args.max_num_seqs,
|
||||
args.max_model_len,
|
||||
args.compilation_config,
|
||||
args.gpu_memory_utilization,
|
||||
args.quantization,
|
||||
),
|
||||
|
||||
@ -4,10 +4,11 @@
|
||||
|
||||
Run `pytest tests/kernels/test_pplx_moe.py`.
|
||||
"""
|
||||
import copy
|
||||
import itertools
|
||||
import textwrap
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -21,7 +22,10 @@ try:
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
|
||||
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
|
||||
_set_vllm_config)
|
||||
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
|
||||
naive_batched_moe)
|
||||
from tests.kernels.quant_utils import dequant
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@ -511,7 +515,8 @@ def pplx_moe(
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_compile: bool = False,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> torch.Tensor:
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
num_experts = w1.shape[0]
|
||||
@ -546,6 +551,7 @@ def pplx_moe(
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||
@ -586,7 +592,11 @@ def pplx_moe(
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
out.fill_(0)
|
||||
if isinstance(out, tuple):
|
||||
out[0].fill_(0)
|
||||
out[1].fill_(0)
|
||||
else:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
@ -626,6 +636,7 @@ def _pplx_moe(
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_internode: bool = False,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
@ -666,6 +677,11 @@ def _pplx_moe(
|
||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
if shared_experts is not None:
|
||||
shared_output = shared_experts(a)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
torch_output = torch_experts(
|
||||
a,
|
||||
w1,
|
||||
@ -696,7 +712,7 @@ def _pplx_moe(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
pplx_output = pplx_moe(
|
||||
pplx_outputs = pplx_moe(
|
||||
group_name,
|
||||
rank,
|
||||
world_size,
|
||||
@ -713,8 +729,24 @@ def _pplx_moe(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
shared_experts=shared_experts,
|
||||
)
|
||||
|
||||
if shared_experts is None:
|
||||
pplx_shared_output = None
|
||||
pplx_output = pplx_outputs
|
||||
assert isinstance(pplx_output, torch.Tensor)
|
||||
else:
|
||||
pplx_shared_output, pplx_output = pplx_outputs
|
||||
|
||||
if shared_output is not None:
|
||||
assert pplx_shared_output is not None
|
||||
chunked_shared_output = chunk_by_rank(
|
||||
shared_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_shared_output.device)
|
||||
else:
|
||||
chunked_shared_output = None
|
||||
|
||||
chunked_batch_output = chunk_by_rank(
|
||||
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
|
||||
|
||||
@ -727,6 +759,15 @@ def _pplx_moe(
|
||||
chunked_batch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
|
||||
if shared_experts is not None:
|
||||
assert chunked_shared_output is not None
|
||||
assert pplx_shared_output is not None
|
||||
torch.testing.assert_close(pplx_shared_output,
|
||||
chunked_shared_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
|
||||
finally:
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
@ -788,7 +829,8 @@ def test_pplx_moe_slow(
|
||||
|
||||
|
||||
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
make_weights: bool, test_fn: Callable):
|
||||
use_shared_experts: bool, make_weights: bool,
|
||||
test_fn: Callable):
|
||||
|
||||
def format_result(msg, ex=None):
|
||||
if ex is not None:
|
||||
@ -803,6 +845,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
else:
|
||||
print(f"PASSED {msg}")
|
||||
|
||||
if use_shared_experts:
|
||||
# Note: this config is only needed for the non-naive shared experts.
|
||||
new_vllm_config = copy.deepcopy(vllm_config)
|
||||
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
|
||||
new_vllm_config.parallel_config.enable_expert_parallel = True
|
||||
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
|
||||
pgi.local_rank)
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
|
||||
[False, True], [None, [128, 128]])
|
||||
@ -819,9 +869,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
use_fp8_w8a8 = False
|
||||
quant_dtype = None
|
||||
|
||||
test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
|
||||
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
|
||||
f"block_shape={block_shape}")
|
||||
test_desc = (
|
||||
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
|
||||
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
|
||||
f"block_shape={block_shape}, use_internode={use_internode}, "
|
||||
f"use_shared_experts={use_shared_experts}")
|
||||
|
||||
if not use_fp8_w8a8 and (per_act_token_quant
|
||||
or block_shape is not None):
|
||||
@ -852,6 +904,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
args["w1_s"] = w1_s
|
||||
args["w2_s"] = w2_s
|
||||
|
||||
if use_shared_experts:
|
||||
args["shared_experts"] = make_shared_experts(
|
||||
n,
|
||||
k,
|
||||
in_dtype=a.dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
)
|
||||
|
||||
try:
|
||||
test_fn(
|
||||
pgi=pgi,
|
||||
@ -891,18 +951,20 @@ def test_pplx_prepare_finalize(
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
|
||||
use_internode, False, _pplx_prepare_finalize)
|
||||
use_internode, False, False, _pplx_prepare_finalize)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@pytest.mark.parametrize("use_shared_experts", [False, True])
|
||||
@requires_pplx
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_pplx_moe(
|
||||
world_dp_size: tuple[int, int],
|
||||
use_internode: bool,
|
||||
use_shared_experts: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True,
|
||||
_pplx_moe)
|
||||
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
|
||||
use_shared_experts, True, _pplx_moe)
|
||||
|
||||
@ -8,6 +8,7 @@ import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
@ -282,3 +283,151 @@ def per_token_cast_to_fp8(
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.b = b.to(dtype=torch.float32)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
return torch.mm(a.to(dtype=torch.float32),
|
||||
self.b).to(self.out_dtype), None
|
||||
|
||||
|
||||
class TestMLP(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = BaselineMM(w1, out_dtype)
|
||||
self.down_proj = BaselineMM(w2, out_dtype)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(x)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_naive_shared_experts(
|
||||
N: int,
|
||||
K: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.nn.Module:
|
||||
w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
|
||||
w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
|
||||
return TestMLP(w1, w2, out_dtype=in_dtype)
|
||||
|
||||
|
||||
class RealMLP(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
hidden_act: str = "silu",
|
||||
quant_config=None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
w1_s: Optional[torch.Tensor] = None,
|
||||
w2_s: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, RowParallelLinear)
|
||||
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w1, requires_grad=False))
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
|
||||
self.gate_up_proj.register_parameter(
|
||||
"input_scale",
|
||||
None) #torch.nn.Parameter(None, requires_grad=False))
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
self.down_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w2, requires_grad=False))
|
||||
self.down_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
|
||||
self.down_proj.register_parameter(
|
||||
"input_scale",
|
||||
None) #torch.nn.Parameter(None, requires_grad=False))
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_shared_experts(
|
||||
N: int,
|
||||
K: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
) -> torch.nn.Module:
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
1,
|
||||
N,
|
||||
K,
|
||||
in_dtype=in_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
)
|
||||
old_dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(in_dtype)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
w1 = w1[0].transpose(0, 1)
|
||||
w2 = w2[0].transpose(0, 1)
|
||||
w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
|
||||
w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
|
||||
quant_config = Fp8Config(True)
|
||||
else:
|
||||
w1 = w1[0]
|
||||
w2 = w2[0]
|
||||
w1_s = None
|
||||
w2_s = None
|
||||
quant_config = None
|
||||
|
||||
return RealMLP(K,
|
||||
N,
|
||||
w1,
|
||||
w2,
|
||||
"silu",
|
||||
quant_config,
|
||||
w1_s=w1_s,
|
||||
w2_s=w2_s)
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -13,11 +13,6 @@ from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
else:
|
||||
FusedMoE = None
|
||||
|
||||
|
||||
class NaiveAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
|
||||
@ -252,7 +252,10 @@ class DeviceCommunicatorBase:
|
||||
|
||||
moe_modules = [
|
||||
module for module in model.modules()
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
# TODO(bnell): Should use isinstance but can't. Maybe search for
|
||||
# presence of quant_method.init_prepare_finalize?
|
||||
if (module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE")
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
self.dp_size = dp_size
|
||||
self.rank_expert_offset = rank_expert_offset
|
||||
self.async_prepare = True
|
||||
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
@ -56,10 +58,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return None
|
||||
return deep_ep.Buffer.get_combine_config(self.dp_size)
|
||||
|
||||
def _do_dispatch(self, tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
rank_topk_ids: torch.Tensor,
|
||||
rank_topk_weights: torch.Tensor, num_experts: int):
|
||||
def _do_dispatch(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
rank_topk_ids: torch.Tensor,
|
||||
rank_topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> Callable:
|
||||
|
||||
has_scales = token_scales is not None
|
||||
|
||||
@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_alignment=1,
|
||||
config=self._get_dispatch_config(),
|
||||
previous_event=None,
|
||||
async_finish=False,
|
||||
async_finish=self.async_prepare,
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
return lambda: self._receiver(
|
||||
event,
|
||||
has_scales,
|
||||
token_data,
|
||||
expert_topk_ids,
|
||||
num_experts,
|
||||
expert_num_tokens_per_expert_list,
|
||||
expert_topk_weights,
|
||||
a1_scale,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
event: deep_ep.EventOverlap,
|
||||
has_scales: bool,
|
||||
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
expert_topk_ids: Optional[torch.Tensor],
|
||||
num_experts: int,
|
||||
expert_num_tokens_per_expert_list: list[int],
|
||||
expert_topk_weights: Optional[torch.Tensor],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
if self.async_prepare:
|
||||
event.current_stream_wait()
|
||||
|
||||
if has_scales:
|
||||
expert_x, expert_x_scale = token_data
|
||||
else:
|
||||
@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# DeepEP's topk_ids output refers to the local experts directly. Offset
|
||||
# the topk_ids to move it back to the global experts space so it aligns
|
||||
# with existing vLLM interfaces.
|
||||
assert expert_topk_ids is not None
|
||||
expert_topk_ids = torch.where(
|
||||
expert_topk_ids == -1,
|
||||
num_experts - 1 if self.rank_expert_offset == 0 else 0,
|
||||
@ -123,10 +159,28 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
|
||||
expert_num_tokens_per_expert_list, device=expert_x.device)
|
||||
|
||||
# Dispatch and Quant
|
||||
# DeepEP kernels only support dispatching block-quantized
|
||||
# activation scales.
|
||||
# Dispatch in bfloat16 and quantize afterwards
|
||||
if not quant_config.is_block_quantized:
|
||||
# Quantize after dispatch.
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||
expert_x,
|
||||
a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
|
||||
def prepare(
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
@ -137,9 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> Callable:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
@ -159,37 +211,37 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
if a1q_scale is not None and a1q_scale.numel() == 1:
|
||||
a1q_scale = a1q_scale.view(1, 1)
|
||||
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts)
|
||||
a1_post_scale = None
|
||||
else:
|
||||
# Dispatch and Quant
|
||||
# DeepEP kernels only support dispatching block-quantized
|
||||
# activation scales.
|
||||
# Dispatch in bfloat16
|
||||
(expert_x, _, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1,
|
||||
token_scales=None,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts)
|
||||
# Quantize after dispatch.
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||
expert_x,
|
||||
a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape)
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
a1_post_scale = a1_scale
|
||||
|
||||
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
return self._do_dispatch(tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
|
||||
topk_ids, num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
@ -75,7 +75,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self,
|
||||
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
per_act_token_quant: bool,
|
||||
@ -110,7 +109,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
return x, x_scales
|
||||
|
||||
def prepare(
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
@ -121,9 +123,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.ReceiverType:
|
||||
|
||||
hidden_size = a1.size(1)
|
||||
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
|
||||
@ -155,16 +155,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
return_recv_hook=True)
|
||||
|
||||
return lambda: self._receiver(hook, expert_x, expert_num_tokens,
|
||||
a1_scale, a1.dtype, quant_config)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
hook: Callable,
|
||||
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
expert_num_tokens: torch.Tensor,
|
||||
a1_scale,
|
||||
a1_dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
hook()
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(
|
||||
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
||||
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
|
||||
topk_ids, num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
|
||||
@ -56,9 +56,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
# TODO(bnell): use quant_config + scales instead of ctor args
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
@ -506,9 +506,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.PrepareResultType:
|
||||
assert a1.dim() == 2
|
||||
assert topk_ids.dim() == 2
|
||||
assert topk_ids.size(0) == a1.size(0)
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, overload
|
||||
from typing import Callable, Literal, Optional, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -215,6 +215,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
layer.shared_experts,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
@ -252,7 +253,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -409,7 +410,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
@ -461,7 +462,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@ -547,7 +548,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb is not False or expert_load_view is not None or \
|
||||
logical_to_physical_map is not None or \
|
||||
logical_replica_count is not None:
|
||||
@ -594,7 +595,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb is not False or expert_load_view is not None or \
|
||||
logical_to_physical_map is not None or \
|
||||
logical_replica_count is not None:
|
||||
@ -633,7 +634,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
@ -948,6 +949,10 @@ class FusedMoE(CustomOp):
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@ -1400,6 +1405,7 @@ class FusedMoE(CustomOp):
|
||||
return [
|
||||
weight.view(self.local_num_experts, -1) for name, weight in weights
|
||||
if name not in NON_EXPERT_WEIGHTS
|
||||
and not name.startswith("_shared_experts.")
|
||||
]
|
||||
|
||||
def set_eplb_state(
|
||||
@ -1582,25 +1588,45 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
og_hidden_states = hidden_states.shape[-1]
|
||||
if self.hidden_size != og_hidden_states:
|
||||
hidden_states = F.pad(hidden_states,
|
||||
(0, self.hidden_size - og_hidden_states),
|
||||
mode='constant',
|
||||
value=0.0)
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we will
|
||||
# switch to using the moe_forward custom op.
|
||||
if current_platform.is_tpu():
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
else:
|
||||
return torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits,
|
||||
self.layer_name)[..., :og_hidden_states]
|
||||
|
||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor):
|
||||
if self.shared_experts is None:
|
||||
if current_platform.is_tpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
fused_output = self.forward_impl(hidden_states, router_logits)
|
||||
assert not isinstance(fused_output, tuple)
|
||||
else:
|
||||
fused_output = torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits, self.layer_name)
|
||||
return fused_output[..., :og_hidden_states]
|
||||
else:
|
||||
if current_platform.is_tpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
shared_output, fused_output = self.forward_impl(
|
||||
hidden_states, router_logits)
|
||||
else:
|
||||
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
||||
hidden_states, router_logits, self.layer_name)
|
||||
return (shared_output[..., :og_hidden_states],
|
||||
fused_output[..., :og_hidden_states])
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
|
||||
@ -1611,7 +1637,10 @@ class FusedMoE(CustomOp):
|
||||
assert (
|
||||
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||
|
||||
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(
|
||||
full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
@ -1652,9 +1681,21 @@ class FusedMoE(CustomOp):
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
assert self.shared_experts is None or isinstance(
|
||||
final_hidden_states, tuple)
|
||||
|
||||
if not skip_result_store:
|
||||
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states, non_blocking=True)
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[
|
||||
chunk_start:chunk_end, :].copy_(final_hidden_states,
|
||||
non_blocking=True)
|
||||
else:
|
||||
full_shared_final_hidden_states[
|
||||
chunk_start:chunk_end, :].copy_(final_hidden_states[0],
|
||||
non_blocking=True)
|
||||
full_fused_final_hidden_states[
|
||||
chunk_start:chunk_end, :].copy_(final_hidden_states[1],
|
||||
non_blocking=True)
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
@ -1675,10 +1716,17 @@ class FusedMoE(CustomOp):
|
||||
chunk_end,
|
||||
skip_result_store=chunk_start_ >= num_tokens)
|
||||
|
||||
return full_final_hidden_states
|
||||
if self.shared_experts is None:
|
||||
return full_fused_final_hidden_states
|
||||
else:
|
||||
return (full_shared_final_hidden_states,
|
||||
full_fused_final_hidden_states)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def forward_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.quant_method is not None
|
||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||
# only when data parallelism (DP) is enabled.
|
||||
@ -1698,6 +1746,15 @@ class FusedMoE(CustomOp):
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# If there are shared experts but we are not using a modular kernel, the
|
||||
# shared experts must be called here
|
||||
if (not isinstance(self.quant_method.fused_experts,
|
||||
FusedMoEModularKernel)
|
||||
and self.shared_experts is not None):
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
@ -1722,14 +1779,30 @@ class FusedMoE(CustomOp):
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
final_hidden_states = get_ep_group().combine(final_hidden_states)
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
# Default set to False. (May have to add shared expert outputs.
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
return final_hidden_states
|
||||
def reduce_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(states)
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is None:
|
||||
return reduce_output(final_hidden_states)
|
||||
else:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0]),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
@ -1784,17 +1857,22 @@ class FusedMoE(CustomOp):
|
||||
return s
|
||||
|
||||
|
||||
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
def moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.quant_method is not None
|
||||
|
||||
assert self.shared_experts is None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
def moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
@ -1807,6 +1885,37 @@ direct_register_custom_op(
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
def moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.shared_experts is not None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward_shared",
|
||||
op_func=moe_forward_shared,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=moe_forward_shared_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
|
||||
# to avoid expensive runtime reflection in model loading code
|
||||
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
|
||||
|
||||
@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Optional, final
|
||||
from typing import Callable, Optional, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#
|
||||
# PrepareResultType is a tuple of:
|
||||
# - quantized + dispatched a.
|
||||
# - quantized + dispatched a1_scales.
|
||||
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
# as big as the number of local experts with the information about the
|
||||
# number of tokens assigned to each local expert.
|
||||
# - Optional dispatched expert topk IDs
|
||||
# - Optional dispatched expert topk weight
|
||||
#
|
||||
# See `prepare` method below.
|
||||
#
|
||||
PrepareResultType = tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]
|
||||
|
||||
ReceiverType = Callable[[], PrepareResultType]
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
) -> PrepareResultType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed
|
||||
for this kernel.
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_async(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not this class implements prepare_async.
|
||||
"""
|
||||
return False
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> ReceiverType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
sure the quantization is consistent for both gemms.
|
||||
- topk_ids: The topk ids.
|
||||
- topk_weights: The topk weights.
|
||||
- num_experts: The total number of experts in the global expert space.
|
||||
- expert_map: A tensor mapping expert indices from the global expert
|
||||
space to the local expert space of the expert parallel shard.
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `prepare`, e.g.
|
||||
|
||||
receiver = obj.prepare_async(...)
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
|
||||
|
||||
is equivalent to:
|
||||
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finalize(
|
||||
self,
|
||||
@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
assert prepare_finalize.activation_format == \
|
||||
fused_experts.activation_formats[0], (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
of weights, w1 and w2, and top-k gating mechanism.
|
||||
@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
shared_output: torch.Tensor
|
||||
|
||||
if (not self.prepare_finalize.supports_async()
|
||||
or self.shared_experts is None):
|
||||
|
||||
# Run shared experts serially with dispatch.
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
receiver = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
assert self.shared_experts is not None
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
return output
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
return shared_output, output
|
||||
|
||||
@ -84,12 +84,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return self.max_num_tokens
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
return torch.int32
|
||||
return torch.uint32
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def prepare(
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
@ -100,9 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.ReceiverType:
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
_validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
|
||||
quant_config.block_shape)
|
||||
|
||||
orig_a_scale_block_shape: Optional[int] = None
|
||||
|
||||
if a1q_scale is not None:
|
||||
scalar_scales = a1q_scale.numel() == 1
|
||||
|
||||
@ -205,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids.view(dtype=torch.uint32),
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=True,
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
return lambda: self._receiver(
|
||||
expert_num_tokens,
|
||||
expert_x,
|
||||
expert_x_scale,
|
||||
a1q,
|
||||
a1q_scale,
|
||||
topk_ids,
|
||||
bound_m,
|
||||
orig_a_scale_block_shape,
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: Optional[torch.Tensor],
|
||||
a1q: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
topk_ids: torch.Tensor,
|
||||
bound_m: Optional[torch.Tensor],
|
||||
orig_a_scale_block_shape: Optional[int],
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=False,
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
if expert_x_scale is not None:
|
||||
@ -218,6 +258,31 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
)
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
|
||||
@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@ -505,7 +505,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@ -474,7 +474,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert self.fused_experts is None
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
@ -358,7 +358,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@ -819,7 +819,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for "
|
||||
@ -1069,7 +1069,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@ -1375,7 +1375,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@ -1608,7 +1608,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -128,7 +128,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -988,7 +988,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@ -540,7 +540,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@ -654,7 +654,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@ -491,7 +491,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||
@ -1366,7 +1366,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -305,7 +305,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@ -554,7 +554,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -226,7 +226,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@ -390,7 +390,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
# Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -291,7 +291,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
6
vllm/model_executor/layers/shared_fused_moe/__init__.py
Normal file
6
vllm/model_executor/layers/shared_fused_moe/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import (
|
||||
SharedFusedMoE)
|
||||
|
||||
__all__ = ["SharedFusedMoE"]
|
||||
@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
|
||||
# TODO(bnell): Add shared + fused combo function? e.g. +
|
||||
class SharedFusedMoE(FusedMoE):
|
||||
"""
|
||||
A FusedMoE operation that also computes the results of shared experts.
|
||||
If an all2all communicator is being used the shared expert computation
|
||||
can be interleaved with the fused all2all dispatch communication step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module,
|
||||
use_overlapped: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._shared_experts = shared_experts
|
||||
self.use_overlapped = use_overlapped
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
return self._shared_experts if self.use_overlapped else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.use_overlapped:
|
||||
shared_out = self._shared_experts(hidden_states)
|
||||
|
||||
# Reduce outputs if necessary, since the MLP should
|
||||
# have been created with reduce_results=False.
|
||||
if (self.reduce_results and self.tp_size > 1
|
||||
and self.must_reduce_shared_expert_outputs()):
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
|
||||
fused_out = super().forward(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
shared_out, fused_out = super().forward(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
return shared_out, fused_out
|
||||
@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@ -147,63 +148,85 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.physical_expert_end = (self.physical_expert_start +
|
||||
self.n_local_physical_experts)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
if config.n_shared_experts is None:
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
self.shared_experts = None
|
||||
else:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
||||
),
|
||||
reduce_results=False,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
shared_experts=self.shared_experts,
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits) * self.routed_scaling_factor
|
||||
fused_moe_out = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output, final_hidden_states = fused_moe_out
|
||||
else:
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
if shared_output is not None:
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
else:
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
final_hidden_states = final_hidden_states + shared_output \
|
||||
* (1. / self.routed_scaling_factor)
|
||||
shared_output = None
|
||||
final_hidden_states = fused_moe_out
|
||||
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
elif self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
shared_output *= (1. / self.routed_scaling_factor)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
final_hidden_states += shared_output
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
|
||||
@ -184,6 +184,8 @@ class Glm4MoE(nn.Module):
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
|
||||
@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.router")
|
||||
|
||||
self.experts = FusedMoE(
|
||||
self.shared_expert = LlamaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
shared_experts=self.shared_expert,
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
|
||||
reduce_results=False,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts")
|
||||
|
||||
self.shared_expert = LlamaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
shared_out = self.shared_expert(hidden_states)
|
||||
routed_out = self.experts(
|
||||
|
||||
shared_out, routed_out = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@ -500,7 +500,8 @@ class Worker(WorkerBase):
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
moe_modules = [
|
||||
module for module in self.model_runner.model.modules()
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
if (module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE")
|
||||
]
|
||||
num_local_experts = moe_modules[0].moe_config.num_local_experts
|
||||
assert all(module.moe_config.num_local_experts == num_local_experts
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user