[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Varun Sundar Rabindranath 2025-10-03 20:13:13 -04:00 committed by GitHub
parent 767cbb011d
commit 7ef40bb983
9 changed files with 264 additions and 101 deletions

View File

@ -93,6 +93,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
@ -114,6 +116,6 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| deepep_high_throughput,</br>pplx | `DeepEPHTPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8` |
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8` |
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
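
The pairing added in this table is what the rest of the commit wires up for gpt-oss: with the deepep_high_throughput backend, MarlinExperts can now sit behind DeepEPHTPrepareAndFinalize. A minimal composition sketch (not part of this diff; the constructor arguments are illustrative and the real signatures may differ):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts

prepare_finalize = make_deepep_ht_prepare_finalize(...)  # hypothetical helper wrapping DeepEPHTPrepareAndFinalize
experts = MarlinExperts(quant_config)                    # quant_config: an mxfp4_w4a16 FusedMoEQuantConfig
moe_kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)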

View File

@ -303,7 +303,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w2.size(1) == K
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))

View File

@ -712,7 +712,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
n = w2.shape[2] * 2
run_cutlass_moe_fp4(

View File

@ -906,7 +906,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = expert_tokens_meta.expert_num_tokens
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids)
assert w1.size(0) == E

View File

@ -4,11 +4,18 @@
from typing import Optional
import torch
from typing_extensions import override
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
marlin_make_workspace_new, marlin_moe_intermediate_size,
maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op
@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
bias2: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
gating_output: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
intermediate_cache13: Optional[torch.Tensor] = None,
intermediate_cache2: Optional[torch.Tensor] = None,
is_k_full: bool = True,
output: Optional[torch.Tensor] = None,
inplace: bool = False) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- gating_output (Optional[torch.Tensor]): The output of the gating
operation (before softmax).
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
0], "Number of tokens mismatch"
if gating_output is not None:
assert hidden_states.shape[0] == gating_output.shape[
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
N = marlin_moe_intermediate_size(w1, w2)
topk = topk_ids.shape[1]
# M block size selection logic
@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
if intermediate_cache2 is None:
intermediate_cache2 = torch.empty(
(M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if intermediate_cache13 is None:
intermediate_cache13 = torch.empty(
(M * topk * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = _resize_cache(intermediate_cache13,
(M * topk, 2 * N))
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = hidden_states.dtype == torch.half or \
@ -200,10 +215,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
use_fp32_reduce=True,
is_zp_float=False).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=output)
if output is None:
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
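
The new intermediate_cache13 / intermediate_cache2 / output arguments let a caller such as MarlinExperts supply pre-allocated buffers; when they are None the function allocates them exactly as before. A hedged sketch of the buffer contract, with illustrative sizes only (a real call also needs the Marlin-packed weights, scales and top-k tensors):

import torch

M, K, N, topk = 32, 2880, 2880, 4  # tokens, hidden size, intermediate size, experts per token
device, dtype = torch.device("cuda"), torch.bfloat16

cache13 = torch.empty((M * topk * max(2 * N, K), ), device=device, dtype=dtype)
cache2 = torch.empty((M * topk, N), device=device, dtype=dtype)
out = torch.empty((M, K), device=device, dtype=dtype)

# fused_marlin_moe(..., intermediate_cache13=cache13,
#                  intermediate_cache2=cache2, output=out)
# writes the reduced result into `out` instead of allocating a fresh tensor.
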
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
@ -211,7 +225,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
gating_output: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
intermediate_cache13: Optional[torch.Tensor] = None,
intermediate_cache2: Optional[torch.Tensor] = None,
is_k_full: bool = True,
output: Optional[torch.Tensor] = None,
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states)
@ -237,3 +254,124 @@ direct_register_custom_op(
op_func=fused_marlin_moe,
fake_impl=fused_marlin_moe_fake,
)
class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
@override
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
assert w1.dim() == 3 and w2.dim() == 3
E = w1.size(0)
K = a1.size(-1)
N = marlin_moe_intermediate_size(w1, w2)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return True
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# The Modular Kernel provisions the output buffer from workspace1. However,
# in the fused_marlin_moe() function the final reduction is essentially
# `torch.sum(workspace1, dim=1, out=output)`.
# Having overlapping input and output tensors for torch.sum seems
# error prone and depends on how torch.sum is implemented.
# For this reason we swap the workspaces and let the output buffer be
# provisioned from workspace2.
# Workspace/IntermediateCache allocation matching fused_marlin_moe():
# workspace1 = (M * topk * max(2 * N, K),)
# workspace2 = (M * topk, N)
# Workspace/IntermediateCache allocation accounting for output buffer
# provisioning
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk * max(2 * N, K), )
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert self.w1_scale is not None
assert self.w2_scale is not None
return fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
bias1=self.w1_bias,
bias2=self.w2_bias,
w1_scale=self.w1_scale,
w2_scale=self.w2_scale,
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
output=output,
# Workspaces are swapped in workspace_shapes() so that the output buffer
# can be provisioned safely. Please refer to workspace_shapes().
intermediate_cache13=workspace2,
intermediate_cache2=workspace13)
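
To see why the swap in workspace_shapes() is safe, here is a worked example with illustrative sizes (not part of this diff). apply() hands workspace2 to fused_marlin_moe() as intermediate_cache13 and workspace13 as intermediate_cache2, so the output buffer (provisioned from workspace1/workspace13, per the comment in workspace_shapes()) never aliases the intermediate_cache3 that the final torch.sum() reads from:

M, topk, N, K = 32, 4, 2880, 2880

workspace13 = (M * topk, max(N, K))        # (128, 2880) -> backs intermediate_cache2 (and the output)
workspace2 = (M * topk * max(2 * N, K), )  # (737280,)   -> backs intermediate_cache1 and intermediate_cache3
output = (M, K)                            # (32, 2880)

# workspace2 is large enough for both Marlin intermediates:
assert M * topk * 2 * N <= workspace2[0]   # intermediate_cache1 view (M * topk, 2 * N)
assert M * topk * K <= workspace2[0]       # intermediate_cache3 view (M * topk, K)
# and workspace13 can hold intermediate_cache2 as well as the (M, K) output:
assert M * topk * N <= workspace13[0] * workspace13[1]
assert M * K <= workspace13[0] * workspace13[1]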

View File

@ -1780,7 +1780,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids)
if global_num_experts == -1:

View File

@ -55,46 +55,6 @@ from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
#
def _moe_problem_size(
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation tensors is
not obvious. It needs to be done this way specifically due to subtle issues
with particular kernels, e.g. the int4 kernels divide the trailing dimension
by two, so it's not "correct" to extract N or K from the trailing dimension
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
to be kept in mind.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
class FusedMoEActivationFormat(Enum):
"""
The standard activation format (num_tokens, hidden dim).
@ -391,6 +351,50 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a1: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, experts that
require a specialized implementation, such as MarlinExperts, are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
#
# Various helpers for accessing quantization parameters from the
# quant_config.
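
A worked example of the default moe_problem_size() above, using small unpacked tensors (illustrative only; quantized/packed layouts such as Marlin's are exactly the cases that need an override):

import torch

a1 = torch.randn(8, 32)                          # M=8 tokens, K=32 hidden size
w1 = torch.randn(4, 64, 32)                      # E=4 experts, w1.size(1)=64 -> N
w2 = torch.randn(4, 32, 64)
topk_ids = torch.zeros(8, 2, dtype=torch.long)   # topk=2

# E = w1.size(0), N = w1.size(1), K = a1.size(-1), M = a1.size(0),
# topk = topk_ids.size(1)  ->  (4, 8, 64, 32, 2)
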
@ -674,7 +678,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
_, M, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids)
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
@ -737,7 +742,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
_, M, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)

View File

@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
@ -92,7 +93,7 @@ def get_mxfp4_backend():
"Please `pip install vllm[flashinfer]` for best results.")
# If FlashInfer is not available, try either Marlin or Triton
if current_platform.get_device_capability(
if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability(
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
"2.8.0"):
logger.info_once("Using Marlin backend")
@ -646,9 +647,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return None
if self.mxfp4_backend == Mxfp4Backend.TRITON:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return mxfp4_w4a16_moe_quant_config(
@ -690,6 +695,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
elif (self.mxfp4_backend == Mxfp4Backend.MARLIN):
return MarlinExperts(self.moe_quant_config)
else:
return OAITritonExperts(self.moe_quant_config)
@ -782,6 +789,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.fused_experts is not None:
return self._route_and_experts(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
@ -815,29 +845,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation,
expert_map=expert_map)
if self.fused_experts is not None:
return self._route_and_experts(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)
assert _can_support_mxfp4(
use_grouped_topk, topk_group, num_expert_group, expert_map,
custom_routing_function, e_score_correction_bias,

View File

@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
supports_router_weight and supports_activation
def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
w2_packed: torch.Tensor):
"""
Given Marlin packed weight matrices w1_packed, and w2_packed,
return the MoE intermediate size N
"""
marlin_tile_size = 16
return w2_packed.size(1) * marlin_tile_size
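
A small usage sketch (illustrative shapes only; real Marlin-packed weights are int32 tensors produced by the Marlin repacking utilities). Since Marlin tiles the weight layout in blocks of 16, N is recovered from the packed w2:

import torch

w1_packed = torch.empty(32, 180, 1440, dtype=torch.int32)  # placeholder packed shape
w2_packed = torch.empty(32, 180, 1440, dtype=torch.int32)  # size(1) = 180
N = marlin_moe_intermediate_size(w1_packed, w2_packed)     # 180 * 16 = 2880
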
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //