[UX] Separate marlin moe config logic from triton moe (#23006)

commit 94096a47c9 (parent a258ad8bcc)
Authored by Michael Goin on 2025-08-16 22:16:42 -04:00; committed by GitHub
2 changed files with 7 additions and 22 deletions

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Fused MoE utilities for GPTQ."""
-import functools
 from typing import Optional
 
 import torch
 
 import vllm._custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    moe_align_block_size, try_get_optimal_moe_config)
+from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types
@@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     N = w2.shape[1] * 16
     topk = topk_ids.shape[1]
 
-    get_config_func = functools.partial(
-        try_get_optimal_moe_config,
-        w1.shape,
-        w2.shape,
-        topk_ids.shape[1],
-        None,
-        is_marlin=True,
-    )
-    config = get_config_func(M)
-
-    block_size_m = config["BLOCK_SIZE_M"]
+    # M block size selection logic
+    # TODO: tune this further for specific models
+    for block_size_m in [8, 16, 32, 48, 64]:
+        if M * topk / E / block_size_m < 0.9:
+            break
 
     if global_num_experts == -1:
         global_num_experts = E
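
Note on the inlined heuristic: it walks the candidate block sizes in ascending order and stops at the first one where the average per-expert workload (M * topk / E rows) drops below 90% occupancy, defaulting to 64 otherwise. A minimal standalone sketch follows; the helper name is hypothetical and not part of this commit:

# Standalone sketch of the inlined heuristic; the helper name is
# hypothetical and not part of this commit.
def pick_marlin_block_size_m(M: int, topk: int, E: int) -> int:
    # Average number of token rows routed to each expert.
    rows_per_expert = M * topk / E
    for block_size_m in [8, 16, 32, 48, 64]:
        # Stop at the first block size that the average per-expert
        # workload fills to less than 90%.
        if rows_per_expert / block_size_m < 0.9:
            break
    return block_size_m

# Example: M=64 tokens, top-2 routing, E=16 experts -> 8 rows/expert.
# 8/8 = 1.0 >= 0.9, but 8/16 = 0.5 < 0.9, so the loop stops at 16.
assert pick_marlin_block_size_m(64, 2, 16) == 16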

vllm/model_executor/layers/fused_moe/fused_moe.py

@@ -801,7 +801,6 @@ def get_default_config(
     K: int,
     topk: int,
     dtype: Optional[str],
-    is_marlin: bool,
     block_shape: Optional[list[int]] = None,
 ) -> dict[str, int]:
     if dtype == "fp8_w8a8" and block_shape is not None:
@@ -832,11 +831,6 @@ def get_default_config(
             config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
         else:
             config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
-    elif is_marlin:
-        for block_size_m in [8, 16, 32, 48, 64]:
-            if M * topk / E / block_size_m < 0.9:
-                break
-        return {"BLOCK_SIZE_M": block_size_m}
     elif M <= E:
         config = {
             "BLOCK_SIZE_M": 16,
@@ -860,7 +854,6 @@ def try_get_optimal_moe_config(
     top_k: int,
     dtype: Optional[str],
     M: int,
-    is_marlin: bool = False,
     block_shape: Optional[list[int]] = None,
 ) -> dict[str, int]:
     from vllm.model_executor.layers.fused_moe import get_config
@@ -883,7 +876,7 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
-                                     is_marlin, block_shape)
+                                     block_shape)
     return config
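
With is_marlin removed, the Triton config lookup takes one fewer argument. An illustrative call against the trimmed signature; it assumes a vLLM build containing this commit, and every shape and token count below is a made-up example value:

# Illustrative lookup with the trimmed signature; shapes and token
# count are made-up example values, not from this commit.
from vllm.model_executor.layers.fused_moe.fused_moe import (
    try_get_optimal_moe_config)

config = try_get_optimal_moe_config(
    (8, 256, 128),  # w1 shape: (num_experts, 2 * intermediate, hidden)
    (8, 128, 128),  # w2 shape: (num_experts, hidden, intermediate)
    2,              # top_k
    None,           # dtype: no quantized-kernel override
    64,             # M: number of tokens in the batch
    block_shape=None,
)
print(config["BLOCK_SIZE_M"])  # e.g. 64 for this M and expert count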