mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:35:17 +08:00
[UX] Separate marlin moe config logic from triton moe (#23006)
This commit is contained in:
parent
a258ad8bcc
commit
94096a47c9
@ -1,14 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Fused MoE utilities for GPTQ."""
|
"""Fused MoE utilities for GPTQ."""
|
||||||
import functools
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||||
moe_align_block_size, try_get_optimal_moe_config)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
|
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
N = w2.shape[1] * 16
|
N = w2.shape[1] * 16
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
|
|
||||||
get_config_func = functools.partial(
|
# M block size selection logic
|
||||||
try_get_optimal_moe_config,
|
# TODO: tune this further for specific models
|
||||||
w1.shape,
|
for block_size_m in [8, 16, 32, 48, 64]:
|
||||||
w2.shape,
|
if M * topk / E / block_size_m < 0.9:
|
||||||
topk_ids.shape[1],
|
break
|
||||||
None,
|
|
||||||
is_marlin=True,
|
|
||||||
)
|
|
||||||
config = get_config_func(M)
|
|
||||||
|
|
||||||
block_size_m = config["BLOCK_SIZE_M"]
|
|
||||||
|
|
||||||
if global_num_experts == -1:
|
if global_num_experts == -1:
|
||||||
global_num_experts = E
|
global_num_experts = E
|
||||||
|
|||||||
@ -801,7 +801,6 @@ def get_default_config(
|
|||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
dtype: Optional[str],
|
dtype: Optional[str],
|
||||||
is_marlin: bool,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> dict[str, int]:
|
) -> dict[str, int]:
|
||||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||||
@ -832,11 +831,6 @@ def get_default_config(
|
|||||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
|
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
|
||||||
else:
|
else:
|
||||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
|
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
|
||||||
elif is_marlin:
|
|
||||||
for block_size_m in [8, 16, 32, 48, 64]:
|
|
||||||
if M * topk / E / block_size_m < 0.9:
|
|
||||||
break
|
|
||||||
return {"BLOCK_SIZE_M": block_size_m}
|
|
||||||
elif M <= E:
|
elif M <= E:
|
||||||
config = {
|
config = {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
@ -860,7 +854,6 @@ def try_get_optimal_moe_config(
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
dtype: Optional[str],
|
dtype: Optional[str],
|
||||||
M: int,
|
M: int,
|
||||||
is_marlin: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> dict[str, int]:
|
) -> dict[str, int]:
|
||||||
from vllm.model_executor.layers.fused_moe import get_config
|
from vllm.model_executor.layers.fused_moe import get_config
|
||||||
@ -883,7 +876,7 @@ def try_get_optimal_moe_config(
|
|||||||
else:
|
else:
|
||||||
# Else use the default config
|
# Else use the default config
|
||||||
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
|
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
|
||||||
is_marlin, block_shape)
|
block_shape)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user