mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 09:06:58 +08:00
[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
767cbb011d
commit
7ef40bb983
@ -93,6 +93,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
|
|||||||
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
||||||
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
|
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
|
||||||
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
|
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
|
||||||
|
|
||||||
|
| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
|
||||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||||
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
||||||
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
|
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
|
||||||
@ -114,6 +116,6 @@ The following table shows "families" of modular kernels that are intended to wor
|
|||||||
|
|
||||||
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|
||||||
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
|
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
|
||||||
| deepep_high_throughput,</br>pplx | `DeepEPHTPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8` |
|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`,</br>`MarlinExperts` |
|
||||||
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8` |
|
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
|
||||||
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
|
||||||
|
|||||||
@ -303,7 +303,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
assert w2.size(1) == K
|
assert w2.size(1) == K
|
||||||
|
|
||||||
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||||
hidden_states, w1, w2, topk_ids)
|
hidden_states, w1, w2, topk_ids)
|
||||||
|
|
||||||
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||||
|
|||||||
@ -712,7 +712,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
):
|
):
|
||||||
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
|
e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
|
||||||
n = w2.shape[2] * 2
|
n = w2.shape[2] * 2
|
||||||
|
|
||||||
run_cutlass_moe_fp4(
|
run_cutlass_moe_fp4(
|
||||||
|
|||||||
@ -906,7 +906,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||||
|
|
||||||
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||||
hidden_states, w1, w2, topk_ids)
|
hidden_states, w1, w2, topk_ids)
|
||||||
|
|
||||||
assert w1.size(0) == E
|
assert w1.size(0) == E
|
||||||
|
|||||||
@ -4,11 +4,18 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||||
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
|
TopKWeightAndReduceNoOP)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
|
marlin_make_workspace_new, marlin_moe_intermediate_size,
|
||||||
|
maybe_warn_marlin_atomic_add)
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
bias2: Optional[torch.Tensor],
|
bias2: Optional[torch.Tensor],
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: Optional[torch.Tensor],
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
quant_type_id: int,
|
quant_type_id: int,
|
||||||
@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
w1_zeros: Optional[torch.Tensor] = None,
|
w1_zeros: Optional[torch.Tensor] = None,
|
||||||
w2_zeros: Optional[torch.Tensor] = None,
|
w2_zeros: Optional[torch.Tensor] = None,
|
||||||
workspace: Optional[torch.Tensor] = None,
|
workspace: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_cache13: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_cache2: Optional[torch.Tensor] = None,
|
||||||
is_k_full: bool = True,
|
is_k_full: bool = True,
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False) -> torch.Tensor:
|
inplace: bool = False) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||||
@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
- w2 (torch.Tensor): The second set of expert weights.
|
- w2 (torch.Tensor): The second set of expert weights.
|
||||||
- w1_scale (torch.Tensor): Scale to be used for w1.
|
- w1_scale (torch.Tensor): Scale to be used for w1.
|
||||||
- w2_scale (torch.Tensor): Scale to be used for w2.
|
- w2_scale (torch.Tensor): Scale to be used for w2.
|
||||||
- gating_output (torch.Tensor): The output of the gating operation
|
- gating_output (Optional[torch.Tensor]): The output of the gating
|
||||||
(before softmax).
|
operation (before softmax).
|
||||||
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
||||||
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
||||||
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
|
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
|
||||||
@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
num_bits = 4 if quant_type in bit4_scalar_types else 8
|
num_bits = 4 if quant_type in bit4_scalar_types else 8
|
||||||
|
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert hidden_states.shape[0] == gating_output.shape[
|
if gating_output is not None:
|
||||||
0], "Number of tokens mismatch"
|
assert hidden_states.shape[0] == gating_output.shape[
|
||||||
|
0], "Number of tokens mismatch"
|
||||||
assert hidden_states.shape[
|
assert hidden_states.shape[
|
||||||
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
|
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
|
||||||
assert hidden_states.shape[1] == w2.shape[2] // (
|
assert hidden_states.shape[1] == w2.shape[2] // (
|
||||||
@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
|
|
||||||
M, K = hidden_states.shape
|
M, K = hidden_states.shape
|
||||||
E = w1.shape[0]
|
E = w1.shape[0]
|
||||||
N = w2.shape[1] * 16
|
N = marlin_moe_intermediate_size(w1, w2)
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
|
|
||||||
# M block size selection logic
|
# M block size selection logic
|
||||||
@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
if workspace is None:
|
if workspace is None:
|
||||||
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
||||||
|
|
||||||
intermediate_cache2 = torch.empty(
|
if intermediate_cache2 is None:
|
||||||
(M * topk_ids.shape[1], N),
|
intermediate_cache2 = torch.empty(
|
||||||
device=hidden_states.device,
|
(M * topk, N),
|
||||||
dtype=hidden_states.dtype,
|
device=hidden_states.device,
|
||||||
)
|
dtype=hidden_states.dtype,
|
||||||
intermediate_cache13 = torch.empty(
|
)
|
||||||
(M * topk_ids.shape[1] * max(2 * N, K), ),
|
|
||||||
device=hidden_states.device,
|
if intermediate_cache13 is None:
|
||||||
dtype=hidden_states.dtype,
|
intermediate_cache13 = torch.empty(
|
||||||
)
|
(M * topk * max(2 * N, K), ),
|
||||||
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
|
device=hidden_states.device,
|
||||||
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
|
dtype=hidden_states.dtype,
|
||||||
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
|
)
|
||||||
intermediate_cache3 = intermediate_cache3.view(-1, K)
|
|
||||||
|
intermediate_cache1 = _resize_cache(intermediate_cache13,
|
||||||
|
(M * topk, 2 * N))
|
||||||
|
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
|
||||||
|
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
|
||||||
|
|
||||||
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
|
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
|
||||||
use_atomic_add = hidden_states.dtype == torch.half or \
|
use_atomic_add = hidden_states.dtype == torch.half or \
|
||||||
@ -200,10 +215,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
use_fp32_reduce=True,
|
use_fp32_reduce=True,
|
||||||
is_zp_float=False).view(-1, topk, K)
|
is_zp_float=False).view(-1, topk, K)
|
||||||
|
|
||||||
output = hidden_states if inplace else torch.empty_like(hidden_states)
|
if output is None:
|
||||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
output = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||||
dim=1,
|
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
|
||||||
out=output)
|
|
||||||
|
|
||||||
|
|
||||||
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
||||||
@ -211,7 +225,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: Optional[torch.Tensor],
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
quant_type_id: int,
|
quant_type_id: int,
|
||||||
@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
|||||||
w1_zeros: Optional[torch.Tensor] = None,
|
w1_zeros: Optional[torch.Tensor] = None,
|
||||||
w2_zeros: Optional[torch.Tensor] = None,
|
w2_zeros: Optional[torch.Tensor] = None,
|
||||||
workspace: Optional[torch.Tensor] = None,
|
workspace: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_cache13: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_cache2: Optional[torch.Tensor] = None,
|
||||||
is_k_full: bool = True,
|
is_k_full: bool = True,
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False) -> torch.Tensor:
|
inplace: bool = False) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
@ -237,3 +254,124 @@ direct_register_custom_op(
|
|||||||
op_func=fused_marlin_moe,
|
op_func=fused_marlin_moe,
|
||||||
fake_impl=fused_marlin_moe_fake,
|
fake_impl=fused_marlin_moe_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
|
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||||
|
# TODO (varun) : Enable activation quantization
|
||||||
|
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
|
||||||
|
super().__init__(quant_config)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def moe_problem_size(
|
||||||
|
self,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
) -> tuple[int, int, int, int, int]:
|
||||||
|
assert w1.dim() == 3 and w2.dim() == 3
|
||||||
|
|
||||||
|
E = w1.size(0)
|
||||||
|
K = a1.size(-1)
|
||||||
|
N = marlin_moe_intermediate_size(w1, w2)
|
||||||
|
|
||||||
|
if a1.dim() == 2:
|
||||||
|
# Make sure we are using the correct a1 (pre-permute).
|
||||||
|
assert topk_ids.size(0) == a1.size(0), \
|
||||||
|
f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||||
|
M = a1.size(0)
|
||||||
|
else:
|
||||||
|
assert a1.dim() == 3
|
||||||
|
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||||
|
M = a1.size(1) # This is max_num_tokens
|
||||||
|
|
||||||
|
assert topk_ids.dim() == 2
|
||||||
|
topk = topk_ids.size(1)
|
||||||
|
|
||||||
|
return E, M, N, K, topk
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||||
|
return TopKWeightAndReduceNoOP()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
return (mk.FusedMoEActivationFormat.Standard,
|
||||||
|
mk.FusedMoEActivationFormat.Standard)
|
||||||
|
|
||||||
|
def supports_chunking(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def workspace_shapes(
|
||||||
|
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||||
|
topk: int, global_num_experts: int, local_num_experts: int,
|
||||||
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
|
||||||
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
|
# Modular Kernel provisions output buffer from workspace1. However in
|
||||||
|
# the fused_marlin_moe() function, the final torch.sum() is defined
|
||||||
|
# essentially as,
|
||||||
|
# `torch.sum(workspace1, dim=1, out=output)`
|
||||||
|
# Having overlapping input and output tensors for torch.sum seems
|
||||||
|
# error prone and depends on how the torch.sum is implemented.
|
||||||
|
# For this reason we swap the workspaces so the sum's input comes from
|
||||||
|
# workspace2.
|
||||||
|
|
||||||
|
# Workspace/IntermediateCache allocation matching fused_marlin_moe()
|
||||||
|
#workspace1 = (M * topk * max(2 * N, K),)
|
||||||
|
#workspace2 = (M * topk, N)
|
||||||
|
|
||||||
|
# Workspace/IntermediateCache allocation accounting for output buffer
|
||||||
|
# provisioning
|
||||||
|
workspace1 = (M * topk, max(N, K))
|
||||||
|
workspace2 = (M * topk * max(2 * N, K), )
|
||||||
|
output = (M, K)
|
||||||
|
|
||||||
|
return (workspace1, workspace2, output, a.dtype)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
output: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
a1q_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
workspace13: torch.Tensor,
|
||||||
|
workspace2: torch.Tensor,
|
||||||
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
):
|
||||||
|
assert self.w1_scale is not None
|
||||||
|
assert self.w2_scale is not None
|
||||||
|
return fused_marlin_moe(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
bias1=self.w1_bias,
|
||||||
|
bias2=self.w2_bias,
|
||||||
|
w1_scale=self.w1_scale,
|
||||||
|
w2_scale=self.w2_scale,
|
||||||
|
gating_output=None,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
activation=activation,
|
||||||
|
expert_map=expert_map,
|
||||||
|
output=output,
|
||||||
|
# Workspaces are swapped in workspace_shapes() to account for proper
|
||||||
|
# output buffer allocation. Please refer to workspace_shapes().
|
||||||
|
intermediate_cache13=workspace2,
|
||||||
|
intermediate_cache2=workspace13)
|
||||||
|
|||||||
@ -1780,7 +1780,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
|
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
|
||||||
]
|
]
|
||||||
|
|
||||||
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||||
hidden_states, w1, w2, topk_ids)
|
hidden_states, w1, w2, topk_ids)
|
||||||
|
|
||||||
if global_num_experts == -1:
|
if global_num_experts == -1:
|
||||||
|
|||||||
@ -55,46 +55,6 @@ from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
|
|||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def _moe_problem_size(
|
|
||||||
a1: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
) -> tuple[int, int, int, int, int]:
|
|
||||||
"""
|
|
||||||
Extract the MoE problem size from the given tensor arguments:
|
|
||||||
- a: The hidden states, input to the MoE layer.
|
|
||||||
- w1: The first set of expert weights.
|
|
||||||
- w2: The second set of expert weights.
|
|
||||||
- topk_ids: The topk ids.
|
|
||||||
|
|
||||||
Note: extracting the problem shape from the weight and activation tensors is
|
|
||||||
not obvious. It needs to be done this way specifically due to subtle issues
|
|
||||||
with particular kernels, e.g. the int4 kernels divide the trailing dimension
|
|
||||||
by two, so it's not "correct" to extract N or K from the trailing dimension
|
|
||||||
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
|
|
||||||
to be kept in mind.
|
|
||||||
"""
|
|
||||||
assert w1.dim() == 3 and w2.dim() == 3
|
|
||||||
E, N, _ = w1.size()
|
|
||||||
K = a1.size(-1)
|
|
||||||
|
|
||||||
if a1.dim() == 2:
|
|
||||||
# Make sure we are using the correct a1 (pre-permute).
|
|
||||||
assert topk_ids.size(0) == a1.size(0), \
|
|
||||||
f"{topk_ids.size(0)} != {a1.size(0)}"
|
|
||||||
M = a1.size(0)
|
|
||||||
else:
|
|
||||||
assert a1.dim() == 3
|
|
||||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
|
||||||
M = a1.size(1) # This is max_num_tokens
|
|
||||||
|
|
||||||
assert topk_ids.dim() == 2
|
|
||||||
topk = topk_ids.size(1)
|
|
||||||
|
|
||||||
return E, M, N, K, topk
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEActivationFormat(Enum):
|
class FusedMoEActivationFormat(Enum):
|
||||||
"""
|
"""
|
||||||
The standard activation format (num_tokens, hidden dim).
|
The standard activation format (num_tokens, hidden dim).
|
||||||
@ -391,6 +351,50 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def moe_problem_size(
|
||||||
|
self,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
) -> tuple[int, int, int, int, int]:
|
||||||
|
"""
|
||||||
|
Extract the MoE problem size from the given tensor arguments:
|
||||||
|
- a1: The hidden states, input to the MoE layer.
|
||||||
|
- w1: The first set of expert weights.
|
||||||
|
- w2: The second set of expert weights.
|
||||||
|
- topk_ids: The topk ids.
|
||||||
|
|
||||||
|
Note: extracting the problem shape from the weight and activation
|
||||||
|
tensors is not obvious. It needs to be done this way specifically
|
||||||
|
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||||
|
divide the trailing dimension by two, so it's not "correct" to
|
||||||
|
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||||
|
some kernels transpose the weights, so this needs to be kept in mind.
|
||||||
|
|
||||||
|
Note: This implementation covers most cases. However, if experts
|
||||||
|
require a specialized implementation, like MarlinExperts, they are free
|
||||||
|
to override this function.
|
||||||
|
"""
|
||||||
|
assert w1.dim() == 3 and w2.dim() == 3
|
||||||
|
E, N, _ = w1.size()
|
||||||
|
K = a1.size(-1)
|
||||||
|
|
||||||
|
if a1.dim() == 2:
|
||||||
|
# Make sure we are using the correct a1 (pre-permute).
|
||||||
|
assert topk_ids.size(0) == a1.size(0), \
|
||||||
|
f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||||
|
M = a1.size(0)
|
||||||
|
else:
|
||||||
|
assert a1.dim() == 3
|
||||||
|
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||||
|
M = a1.size(1) # This is max_num_tokens
|
||||||
|
|
||||||
|
assert topk_ids.dim() == 2
|
||||||
|
topk = topk_ids.size(1)
|
||||||
|
|
||||||
|
return E, M, N, K, topk
|
||||||
|
|
||||||
#
|
#
|
||||||
# Various helpers for accessing quantization parameters from the
|
# Various helpers for accessing quantization parameters from the
|
||||||
# quant_config.
|
# quant_config.
|
||||||
@ -674,7 +678,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
_, M, N, K, top_k = self.fused_experts.moe_problem_size(
|
||||||
|
a1q, w1, w2, topk_ids)
|
||||||
|
|
||||||
(workspace13_shape, workspace2_shape, fused_out_shape,
|
(workspace13_shape, workspace2_shape, fused_out_shape,
|
||||||
workspace_dtype) = self.fused_experts.workspace_shapes(
|
workspace_dtype) = self.fused_experts.workspace_shapes(
|
||||||
@ -737,7 +742,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
_, M, N, K, top_k = self.fused_experts.moe_problem_size(
|
||||||
|
a1q, w1, w2, topk_ids)
|
||||||
|
|
||||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
|||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
|
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
|
||||||
mxfp4_w4a16_moe_quant_config)
|
mxfp4_w4a16_moe_quant_config)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
|
||||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||||
OAITritonExperts)
|
OAITritonExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||||
@ -92,7 +93,7 @@ def get_mxfp4_backend():
|
|||||||
"Please `pip install vllm[flashinfer]` for best results.")
|
"Please `pip install vllm[flashinfer]` for best results.")
|
||||||
|
|
||||||
# If FlashInfer is not available, try either Marlin or Triton
|
# If FlashInfer is not available, try either Marlin or Triton
|
||||||
if current_platform.get_device_capability(
|
if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability(
|
||||||
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
|
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
|
||||||
"2.8.0"):
|
"2.8.0"):
|
||||||
logger.info_once("Using Marlin backend")
|
logger.info_once("Using Marlin backend")
|
||||||
@ -646,9 +647,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
|
||||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||||
return None
|
return mxfp4_w4a16_moe_quant_config(
|
||||||
|
w1_bias=layer.w13_bias,
|
||||||
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
w2_bias=layer.w2_bias,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
)
|
||||||
|
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||||
w1_scale = self.w13_precision_config
|
w1_scale = self.w13_precision_config
|
||||||
w2_scale = self.w2_precision_config
|
w2_scale = self.w2_precision_config
|
||||||
return mxfp4_w4a16_moe_quant_config(
|
return mxfp4_w4a16_moe_quant_config(
|
||||||
@ -690,6 +695,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
}
|
}
|
||||||
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
elif (self.mxfp4_backend == Mxfp4Backend.MARLIN):
|
||||||
|
return MarlinExperts(self.moe_quant_config)
|
||||||
else:
|
else:
|
||||||
return OAITritonExperts(self.moe_quant_config)
|
return OAITritonExperts(self.moe_quant_config)
|
||||||
|
|
||||||
@ -782,6 +789,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
if enable_eplb:
|
if enable_eplb:
|
||||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||||
|
|
||||||
|
if self.fused_experts is not None:
|
||||||
|
return self._route_and_experts(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
router_logits,
|
||||||
|
top_k,
|
||||||
|
renormalize,
|
||||||
|
use_grouped_topk,
|
||||||
|
topk_group,
|
||||||
|
num_expert_group,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
custom_routing_function,
|
||||||
|
scoring_func,
|
||||||
|
e_score_correction_bias,
|
||||||
|
apply_router_weight_on_input,
|
||||||
|
activation,
|
||||||
|
enable_eplb,
|
||||||
|
expert_load_view,
|
||||||
|
logical_to_physical_map,
|
||||||
|
logical_replica_count,
|
||||||
|
)
|
||||||
|
|
||||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@ -815,29 +845,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
|
||||||
if self.fused_experts is not None:
|
|
||||||
return self._route_and_experts(
|
|
||||||
layer,
|
|
||||||
x,
|
|
||||||
router_logits,
|
|
||||||
top_k,
|
|
||||||
renormalize,
|
|
||||||
use_grouped_topk,
|
|
||||||
topk_group,
|
|
||||||
num_expert_group,
|
|
||||||
global_num_experts,
|
|
||||||
expert_map,
|
|
||||||
custom_routing_function,
|
|
||||||
scoring_func,
|
|
||||||
e_score_correction_bias,
|
|
||||||
apply_router_weight_on_input,
|
|
||||||
activation,
|
|
||||||
enable_eplb,
|
|
||||||
expert_load_view,
|
|
||||||
logical_to_physical_map,
|
|
||||||
logical_replica_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert _can_support_mxfp4(
|
assert _can_support_mxfp4(
|
||||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||||
custom_routing_function, e_score_correction_bias,
|
custom_routing_function, e_score_correction_bias,
|
||||||
|
|||||||
@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
|
|||||||
supports_router_weight and supports_activation
|
supports_router_weight and supports_activation
|
||||||
|
|
||||||
|
|
||||||
|
def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
|
||||||
|
w2_packed: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Given the Marlin-packed weight matrices w1_packed and w2_packed,
|
||||||
|
return the MoE intermediate size N
|
||||||
|
"""
|
||||||
|
marlin_tile_size = 16
|
||||||
|
return w2_packed.size(1) * marlin_tile_size
|
||||||
|
|
||||||
|
|
||||||
def marlin_make_workspace(output_size_per_partition: int,
|
def marlin_make_workspace(output_size_per_partition: int,
|
||||||
device: torch.device) -> torch.Tensor:
|
device: torch.device) -> torch.Tensor:
|
||||||
max_workspace_size = (output_size_per_partition //
|
max_workspace_size = (output_size_per_partition //
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user