mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 02:45:19 +08:00
[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 (#28971)
Signed-off-by: Xin Yang <xyangx@amazon.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
35657bcd7a
commit
745a3bae1a
250
tests/kernels/moe/test_modular_oai_triton_moe.py
Normal file
250
tests/kernels/moe/test_modular_oai_triton_moe.py
Normal file
@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test modular OAI Triton MoE
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
pytest.skip(
|
||||
"triton_kernels not found, skipping all related tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MNK = [
|
||||
(1, 512, 384),
|
||||
(1, 2880, 2880),
|
||||
(2, 512, 384),
|
||||
(2, 2880, 2880),
|
||||
(32, 2880, 2880),
|
||||
(64, 2880, 2880),
|
||||
]
|
||||
|
||||
|
||||
def unshuffle_weight(w: torch.Tensor):
|
||||
first = w[..., ::2]
|
||||
second = w[..., 1::2]
|
||||
return torch.concat((first, second), dim=-1)
|
||||
|
||||
|
||||
def make_weights(dtype, k, n, e):
|
||||
w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
|
||||
w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")
|
||||
|
||||
w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
|
||||
w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")
|
||||
|
||||
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
|
||||
w1_bias_tri = w1_bias.clone()
|
||||
w2_bias_tri = w2_bias.clone()
|
||||
w1_bias_tri = w1_bias_tri.to(torch.float32)
|
||||
w2_bias_tri = w2_bias_tri.to(torch.float32)
|
||||
|
||||
# shuffle weights
|
||||
w1_tri = shuffle_weight(w1_tri)
|
||||
w1_bias_tri = shuffle_weight(w1_bias_tri)
|
||||
|
||||
# quant triton_weights
|
||||
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
|
||||
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
|
||||
w1 = unshuffle_weight(w1)
|
||||
|
||||
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
|
||||
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
|
||||
|
||||
num_warps = 8
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w_scale_layout, w_scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
|
||||
)
|
||||
|
||||
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
|
||||
w1_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
|
||||
w2_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w1_precision_config = PrecisionConfig(
|
||||
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
|
||||
return (
|
||||
w1,
|
||||
w2,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
)
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
if limit is not None:
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
return out_glu * (x_linear + 1)
|
||||
|
||||
|
||||
def torch_moe_impl(
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, K, 2N)
|
||||
w2: torch.Tensor, # (E, N, K)
|
||||
w1_bias: torch.Tensor, # (E, 2N)
|
||||
w2_bias: torch.Tensor, # (E, K)
|
||||
topk_weights: torch.Tensor, # (M, topk)
|
||||
topk_ids: torch.Tensor, # (M, topk)
|
||||
):
|
||||
w1 = w1[topk_ids, ...]
|
||||
w1_bias = w1_bias[topk_ids, ...]
|
||||
hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
|
||||
hidden_states = swiglu(hidden_states, limit=7)
|
||||
|
||||
w2 = w2[topk_ids, ...]
|
||||
w2_bias = w2_bias[topk_ids, ...]
|
||||
hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias
|
||||
|
||||
# Weighted sum of experts
|
||||
hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def oai_triton_moe_impl(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: "PrecisionConfig",
|
||||
w2_scale: "PrecisionConfig",
|
||||
w1_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
unfused: bool = False,
|
||||
) -> torch.Tensor:
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
if unfused:
|
||||
fused_experts = UnfusedOAITritonExperts(quant_config)
|
||||
else:
|
||||
fused_experts = OAITritonExperts(quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
|
||||
|
||||
return mk.forward(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation="swigluoai",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("m,n,k", MNK)
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("topk", [4])
|
||||
@pytest.mark.parametrize("unfused", [True, False])
|
||||
def test_oai_triton_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
unfused: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
(
|
||||
w1,
|
||||
w2,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
) = make_weights(dtype, k, n, num_experts)
|
||||
|
||||
x = torch.randn((m, k), dtype=dtype, device="cuda")
|
||||
router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)
|
||||
|
||||
out = oai_triton_moe_impl(
|
||||
x,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
num_experts,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
unfused,
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
|
||||
@ -20,15 +20,24 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
modular_marlin_fused_moe,
|
||||
MarlinExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
TritonExperts,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||
FusedMoEModularMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
|
||||
from .utils import _get_lora_device
|
||||
|
||||
@ -114,15 +123,23 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.base_layer.ensure_moe_quant_config_init()
|
||||
quant_config = self.base_layer.quant_method.moe_quant_config
|
||||
|
||||
m_fused_moe_fn = (
|
||||
modular_triton_fused_moe(
|
||||
quant_config, shared_experts=self.base_layer.shared_experts
|
||||
)
|
||||
if not quant_config.use_mxfp4_w4a16
|
||||
else modular_marlin_fused_moe(
|
||||
quant_config, shared_experts=self.base_layer.shared_experts
|
||||
)
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP()
|
||||
m_fused_moe_fn = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
self.base_layer.quant_method.select_gemm_impl(
|
||||
prepare_finalize, self.base_layer
|
||||
),
|
||||
self.base_layer.shared_experts,
|
||||
getattr(self.base_layer, "shared_experts_stream", None),
|
||||
)
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
assert isinstance(
|
||||
m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
|
||||
)
|
||||
else:
|
||||
assert isinstance(
|
||||
m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
|
||||
)
|
||||
|
||||
def fwd_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
@ -376,3 +377,148 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
intermediate_cache=workspace2,
|
||||
a1q_scale=a1q_scale,
|
||||
)
|
||||
|
||||
|
||||
class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
"""
|
||||
A Triton based MoE expert class that operates on expert standard
|
||||
format and explicitly keeps the activation and reduction (moe_sum) steps
|
||||
unfused from the matmul_ogs kernel. This exposes injection points
|
||||
for activation and moe_sum.
|
||||
|
||||
One use case for it is to inject LoRA modules on the activation and moe_sum.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
# TODO (varun) : Enable activation quantization
|
||||
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
|
||||
super().__init__(quant_config)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (
|
||||
mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard,
|
||||
)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspace are allocated inside the kernel
|
||||
workspace1 = (M * topk, N // 2)
|
||||
workspace2 = (M * topk, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
|
||||
ops.moe_sum(input, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if self.quant_config is None:
|
||||
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
routing_data, gather_indx, scatter_indx = self._make_routing_data(
|
||||
topk_ids, topk_weights, local_num_experts
|
||||
)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert (
|
||||
self.quant_config.w1_bias is None
|
||||
or self.quant_config.w1_bias.dtype == torch.float32
|
||||
)
|
||||
assert (
|
||||
self.quant_config.w2_bias is None
|
||||
or self.quant_config.w2_bias.dtype == torch.float32
|
||||
)
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.ndim == 2
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
batch_dim = 1
|
||||
M, K = hidden_states.shape
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
# Note that the output tensor might be in workspace13
|
||||
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
|
||||
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
|
||||
intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
|
||||
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
self.quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=self.quant_config.w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=None,
|
||||
y=intermediate_cache1,
|
||||
)
|
||||
|
||||
self.activation(
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
# matmul_ogs grouped reduction fuse sum across multiple experts:
|
||||
# y[dst_ind // n_expts_act, :] += x[src_ind, :]
|
||||
# Need to set n_expts_act to 1 to unfuse moe_sum
|
||||
routing_data.n_expts_act = 1
|
||||
|
||||
matmul_ogs(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
self.quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=self.quant_config.w2_precision,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=intermediate_cache3,
|
||||
)
|
||||
|
||||
self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
|
||||
|
||||
@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
@ -83,8 +84,21 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
|
||||
if not current_platform.is_cuda():
|
||||
return Mxfp4Backend.NONE
|
||||
|
||||
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
|
||||
return Mxfp4Backend.MARLIN
|
||||
# If FlashInfer is not available, try either Marlin or Triton
|
||||
triton_kernels_supported = (
|
||||
has_triton_kernels()
|
||||
and is_torch_equal_or_newer("2.8.0")
|
||||
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
|
||||
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
|
||||
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
|
||||
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
|
||||
)
|
||||
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
|
||||
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
|
||||
return Mxfp4Backend.MARLIN
|
||||
|
||||
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
|
||||
|
||||
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
||||
@ -854,6 +868,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
return MarlinExperts(self.moe_quant_config)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
if self.moe.is_lora_enabled:
|
||||
return UnfusedOAITritonExperts(self.moe_quant_config)
|
||||
return OAITritonExperts(self.moe_quant_config)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user