From 94096a47c92c4a53ad44cfffdca918669c0f89e0 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sat, 16 Aug 2025 22:16:42 -0400
Subject: [PATCH] [UX] Separate marlin moe config logic from triton moe
 (#23006)

---
 .../layers/fused_moe/fused_marlin_moe.py      | 20 ++++++-------------
 .../layers/fused_moe/fused_moe.py             |  9 +--------
 2 files changed, 7 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index a49d41c18438e..3c6ece6737e4d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Fused MoE utilities for GPTQ."""
-import functools
 from typing import Optional
 
 import torch
 
 import vllm._custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    moe_align_block_size, try_get_optimal_moe_config)
+from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types
@@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     N = w2.shape[1] * 16
     topk = topk_ids.shape[1]
 
-    get_config_func = functools.partial(
-        try_get_optimal_moe_config,
-        w1.shape,
-        w2.shape,
-        topk_ids.shape[1],
-        None,
-        is_marlin=True,
-    )
-    config = get_config_func(M)
-
-    block_size_m = config["BLOCK_SIZE_M"]
+    # M block size selection logic
+    # TODO: tune this further for specific models
+    for block_size_m in [8, 16, 32, 48, 64]:
+        if M * topk / E / block_size_m < 0.9:
+            break
 
     if global_num_experts == -1:
         global_num_experts = E
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index e58a9e568d4a4..3579ca22bafc7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -801,7 +801,6 @@ def get_default_config(
     K: int,
     topk: int,
     dtype: Optional[str],
-    is_marlin: bool,
     block_shape: Optional[list[int]] = None,
 ) -> dict[str, int]:
     if dtype == "fp8_w8a8" and block_shape is not None:
@@ -832,11 +831,6 @@ def get_default_config(
             config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
         else:
             config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
-    elif is_marlin:
-        for block_size_m in [8, 16, 32, 48, 64]:
-            if M * topk / E / block_size_m < 0.9:
-                break
-        return {"BLOCK_SIZE_M": block_size_m}
     elif M <= E:
         config = {
             "BLOCK_SIZE_M": 16,
@@ -860,7 +854,6 @@ def try_get_optimal_moe_config(
     top_k: int,
     dtype: Optional[str],
     M: int,
-    is_marlin: bool = False,
     block_shape: Optional[list[int]] = None,
 ) -> dict[str, int]:
     from vllm.model_executor.layers.fused_moe import get_config
@@ -883,7 +876,7 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
-                                    is_marlin, block_shape)
+                                    block_shape)
     return config
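
Note: the BLOCK_SIZE_M heuristic that this patch inlines into fused_marlin_moe (previously the is_marlin branch of get_default_config) can be read in isolation as the sketch below. This is a minimal standalone illustration only; the helper name select_marlin_block_size_m is hypothetical and not part of vLLM's API.

    # Standalone sketch of the inlined heuristic (hypothetical helper name).
    def select_marlin_block_size_m(M: int, topk: int, E: int) -> int:
        """Pick BLOCK_SIZE_M for the marlin MoE kernel.

        M * topk is the total number of (token, expert) assignments, so
        M * topk / E is the average number of rows routed to each expert.
        The loop takes the first candidate block size that keeps the
        per-expert occupancy below 0.9 blocks (roughly, one block per
        expert is enough); if none qualifies, the last candidate (64)
        is kept, mirroring the for/break pattern in the patch.
        """
        block_size_m = 64
        for block_size_m in [8, 16, 32, 48, 64]:
            if M * topk / E / block_size_m < 0.9:
                break
        return block_size_m

    if __name__ == "__main__":
        # e.g. 128 tokens, top-2 routing, 64 experts -> ~4 rows per expert,
        # so the smallest candidate (8) already satisfies the threshold.
        print(select_marlin_block_size_m(M=128, topk=2, E=64))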