[UX] Separate marlin moe config logic from triton moe (#23006)

This commit is contained in:
Michael Goin 2025-08-16 22:16:42 -04:00 committed by GitHub
parent a258ad8bcc
commit 94096a47c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 22 deletions

View File

@@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types
@@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
N = w2.shape[1] * 16
topk = topk_ids.shape[1]
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
# M block size selection logic
# TODO: tune this further for specific models
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break
if global_num_experts == -1:
global_num_experts = E

View File

@@ -801,7 +801,6 @@ def get_default_config(
K: int,
topk: int,
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
@@ -832,11 +831,6 @@ def get_default_config(
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
elif is_marlin:
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break
return {"BLOCK_SIZE_M": block_size_m}
elif M <= E:
config = {
"BLOCK_SIZE_M": 16,
@@ -860,7 +854,6 @@ def try_get_optimal_moe_config(
top_k: int,
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config
@@ -883,7 +876,7 @@ def try_get_optimal_moe_config(
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin, block_shape)
block_shape)
return config