[gpt-oss] triton kernel mxfp4 (#22421)

Signed-off-by: <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Yongye Zhu 2025-08-08 08:24:07 -07:00 committed by GitHub
parent e5ebeeba53
commit e789cad6b8
8 changed files with 755 additions and 9 deletions

.gitignore (vendored)

@@ -4,6 +4,9 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
# triton jit
.triton
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]


@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, fields
import pytest
import torch
import torch.nn.functional as F
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
upcast_from_mxfp)
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up
def deshuffle(w: torch.Tensor):
first = w[..., ::2]
second = w[..., 1::2]
deshuffled = torch.concat((first, second), dim=-1)
return deshuffled
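# Illustrative sanity check (not part of the original commit): deshuffle is the
# inverse of shuffle_weight, which interleaves the two column halves, so
# deshuffling the upcast reference weights below recovers the original layout.
# The helper name is ours and purely for illustration.
def _deshuffle_roundtrip_check():
    w = torch.arange(12, dtype=torch.float32).reshape(2, 6)
    # shuffle_weight(w) -> [[0, 3, 1, 4, 2, 5], [6, 9, 7, 10, 8, 11]]
    assert torch.equal(deshuffle(shuffle_weight(w)), w)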
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
randbits = [torch.randperm(E) for _ in range(M)]
x_list = [
(-1)**i *
((16384 +
((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
for i, bits in enumerate(randbits)
]
exp_data = torch.stack(x_list).to(
device="cuda") # simulating gate_output (M, E)
# create input tensor
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")
w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")
exp_data_tri = exp_data.clone()
x_tri = x.clone()
w1_tri = w1.clone()
w2_tri = w2.clone()
w1_bias_tri = w1_bias.clone()
w2_bias_tri = w2_bias.clone()
w1_bias_tri = w1_bias_tri.to(torch.float32)
w2_bias_tri = w2_bias_tri.to(torch.float32)
dtype_dict = {
"bf16": torch.bfloat16,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2
}
x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
if w_dtype != "mx4":
# simulate quantization support on reference impl
w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)
# the triton moe kernel uses transposed shapes for matmul
w1_tri = w1_tri.transpose(-2, -1)
w2_tri = w2_tri.transpose(-2, -1)
# shuffle weights
w1_tri = shuffle_weight(w1_tri)
w1_bias_tri = shuffle_weight(w1_bias_tri)
# quantize the triton weights
x_tri = x.to(dtype_dict[a_dtype])
if w_dtype != "mx4":
pytest.skip("NYI")
else: # quantize to mx4
# be careful with the padding here: the activation padding needs to be a
# multiple of 64 (the actual engine is not implemented)
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
w2_bottom_pad = w1_right_pad // 2
w2_right_pad = w1_bottom_pad
x_pad = w1_bottom_pad
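# Illustrative note (added for clarity, not in the original commit): after the
# transpose, w1_tri is (E, K, 2N), so its rows (K) are padded to a multiple of
# 64 and its columns (2N) to a multiple of 128. Since swiglu halves 2N to N,
# w2_tri's input dim (N) is padded by half of w1's column padding, while its
# output dim (K) and the activation x reuse w1's row padding, keeping all
# intermediate shapes aligned.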
w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
mode="constant",
value=0)
w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
mode="constant",
value=0)
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
mode="constant",
value=0)
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
mode="constant",
value=0)
x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1)
w_scale_layout, w_scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
**w_layout_opts)
w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
w_scale_layout, **w_scale_layout_opts)
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
**w_layout_opts)
w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
w_scale_layout, **w_scale_layout_opts)
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
# truncate so the rest can run properly
w1 = w1[..., :K, :2 * N]
w2 = w2[..., :N, :K]
w1 = deshuffle(w1)
w1 = w1.transpose(-1, -2).contiguous()
w2 = w2.transpose(-1, -2).contiguous()
return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)
@dataclass
class ModelConfig:
num_hidden_layers: int = 36
num_experts: int = 128
experts_per_token: int = 4
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
head_dim: int = 64
num_attention_heads: int = 64
num_key_value_heads: int = 8
sliding_window: int = 128
initial_context_length: int = 4096
rope_theta: float = 150000.0
rope_scaling_factor: float = 32.0
rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
x_glu = x_glu.clamp(max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
if limit is not None:
x_linear = x_linear.clamp(min=-limit, max=limit)
return out_glu * (x_linear + 1)
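# Worked example (illustrative, not part of the original commit): with
# limit=None and a single (x_glu, x_linear) = (1.0, 1.0) pair, the reference
# swiglu reduces to sigmoid(1.702) * (1 + 1) ~= 1.69. The helper below is ours.
def _swiglu_scalar_check():
    x = torch.tensor([1.0, 1.0])
    expected = torch.sigmoid(torch.tensor(1.702)) * 2.0
    assert torch.allclose(swiglu(x, limit=None), expected)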
def oai_moe_forward(
hidden_states: torch.Tensor, # (M, K)
w1: torch.Tensor, # (E, 2N, K)
w1_bias: torch.Tensor, # (E, 2N)
w2: torch.Tensor, # (E, K, N)
w2_bias: torch.Tensor, # (E, K)
gating_output: torch.Tensor, # (M, E)
topk: int):
# model.py 309:330, assuming gating and norm
t = hidden_states
experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
expert_indices = experts.indices
# MLP #1
mlp1_weight = w1[expert_indices, ...]
mlp1_bias = w1_bias[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, limit=7)
# MLP #2
mlp2_weight = w2[expert_indices, ...]
mlp2_bias = w2_bias[expert_indices, ...]
t = torch.einsum("beck,bek->bec", mlp2_weight, t)
t += mlp2_bias
# Weighted sum of experts
t = torch.einsum("bec,be->bc", t, expert_weights)
return t
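# Shape walkthrough (added for clarity, not in the original commit):
#   t: (M, K); mlp1_weight gathered per token: (M, topk, 2N, K)
#   einsum("beck,bk->bec") -> (M, topk, 2N); swiglu -> (M, topk, N)
#   mlp2_weight: (M, topk, K, N); einsum("beck,bek->bec") -> (M, topk, K)
#   weighted sum over the topk experts -> (M, K)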
@dataclass
class Case:
a_dtype: str
w_dtype: str
@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")
]
],
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
M = num_token
E = ModelConfig.num_experts
K = ModelConfig.hidden_size
N = ModelConfig.intermediate_size // tp
topk = ModelConfig.experts_per_token
x, w1, w1_bias, w2, w2_bias, exp_data, \
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
w2_bias_tri, pc1, pc2 = init_compute_data(
M, K, N, E, a_dtype, w_dtype, num_warps=8)
out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri,
w1=w1_tri,
w2=w2_tri,
gating_output=exp_data_tri,
topk=topk,
renormalize=True,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2)
out_triton_monolithic = out_triton_monolithic[..., :K]
out_ref = oai_moe_forward(hidden_states=x,
w1=w1,
w1_bias=w1_bias,
w2=w2,
w2_bias=w2_bias,
gating_output=exp_data,
topk=topk)
assert_close(ref=out_ref,
tri=out_triton_monolithic,
maxtol=0.025,
rmstol=0.005)
def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
topk: int, renormalize: bool, w1_bias: torch.Tensor,
w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
w2_precision: PrecisionConfig) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
num_local_experts=w1.shape[0],
rank=0),
BatchedOAITritonExperts(
None,
max_num_tokens=max_num_tokens,
num_dispatchers=1,
w1_precision=w1_precision,
w2_precision=w2_precision,
),
)
extra_expert_args = {
"w1_bias": w1_bias,
"w2_bias": w2_bias,
}
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
return fused_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
extra_expert_args=extra_expert_args,
)
@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")
]
],
)
@pytest.mark.parametrize("num_token", [64])
@pytest.mark.parametrize("ep", [1, 2, 4, 8])
def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
M = num_token
E = ModelConfig.num_experts // ep
K = ModelConfig.hidden_size
N = ModelConfig.intermediate_size
topk = ModelConfig.experts_per_token
x, w1, w1_bias, w2, w2_bias, exp_data, \
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
w2_bias_tri, pc1, pc2 = init_compute_data(
M, K, N, E, a_dtype, w_dtype, num_warps=4)
out_tri = batched_moe(a=x_tri,
w1=w1_tri,
w2=w2_tri,
gating_output=exp_data_tri,
topk=topk,
renormalize=True,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2)
out_tri = out_tri[..., :K]
out_ref = oai_moe_forward(hidden_states=x,
w1=w1,
w1_bias=w1_bias,
w2=w2,
w2_bias=w2_bias,
gating_output=exp_data,
topk=topk)
assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
def test_unit_shuffle():
N = ModelConfig.intermediate_size
K = ModelConfig.hidden_size
m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")
x = torch.randn(K, dtype=torch.bfloat16, device="cuda")
m_shuffled = shuffle_weight(m)
out_ref = x @ m
out_ref = swiglu(out_ref, limit=1.0)
out = x @ m_shuffled
out = triton_kernels.swiglu.swiglu_torch(
out,
alpha=1.702,
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))
assert_close(ref=out_ref, tri=out)
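# Illustrative note (not part of the original commit): this test passes because
# shuffle_weight interleaves the gate and linear halves column-wise, which
# appears to match the interleaved (gate, linear) layout that triton_kernels'
# swiglu (exercised here via swiglu_torch) expects, so it agrees with the
# reference swiglu applied to the unshuffled [gate | linear] layout.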


@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
if True:
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
PrecisionConfig, matmul_ogs)
from triton_kernels.routing import routing
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
w2, # Tensor or triton_kernels.Tensor
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision=None, # PrecisionConfig or None
w2_precision=None, # PrecisionConfig or None
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(gating_output,
topk,
sm_first=not renormalize)
return triton_kernel_fused_experts(
None,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=w1_precision,
w2_precision=w2_precision,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
# This is a triton implementation of the fused_experts function
def triton_kernel_fused_experts(
output_tensor: torch.Tensor,
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
w2, # Tensor or triton_kernels.Tensor
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision=None, # PrecisionConfig or None
w2_precision=None, # PrecisionConfig or None
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert w1_bias is None or w1_bias.dtype == torch.float32
assert w2_bias is None or w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit), 2)
gammas = routing_data.gate_scal if routing_data else None
intermediate_cache1 = matmul_ogs(
hidden_states,
w1,
w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
w2,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
return intermediate_cache3
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config, max_num_tokens: int, num_dispatchers: int,
w1_precision: PrecisionConfig, w2_precision: PrecisionConfig):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# workspaces are allocated inside the kernel
assert a.dim() == 2
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
workspace2 = (0, 0, 0)
output = (num_experts, max_num_tokens * num_dp, N)
return (output, workspace2, output, a.dtype)
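# Example (illustrative, not in the original commit): with max_num_tokens=64,
# num_dispatchers=1, and 16 local experts, the batched output buffer is
# (16, 64, N); the intermediate workspace stays empty because matmul_ogs
# allocates its own scratch space.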
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
["w1_bias", "w2_bias"]))
return triton_kernel_fused_experts(
output,
hidden_states,
w1,
w2,
None,
None,
None,
activation=activation,
apply_router_weight_on_input=False,
use_fp8_w8a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
a2_scale=a2_scale)


@@ -36,7 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
has_triton_kernels, is_torch_equal_or_newer, round_up)
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike():
@@ -723,10 +723,17 @@ class FusedMoE(torch.nn.Module):
self.global_num_experts = num_experts + num_redundant_experts
# we pad globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4" and (
envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
if quant_config and quant_config.get_name() == "mxfp4":
if not is_torch_equal_or_newer("2.8.0"):
raise RuntimeError("Mxfp4 on hopper requires torch >= 2.8.0")
if current_platform.is_device_capability(
90) and not has_triton_kernels():
raise NotImplementedError(
"Triton kernels must be installed for mxfp4 on hopper")
if (current_platform.is_rocm()
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config


@@ -8,16 +8,19 @@ from torch.nn.parameter import Parameter
from vllm import envs
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4)
_can_support_mxfp4, _swizzle_mxfp4)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import next_power_of_2, round_up
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
@@ -39,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 100
return 90
@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -100,11 +103,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# to hold non-uniformly sharded tensors as well as for swizzling;
# additional padding is applied to increase performance
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
elif current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
@@ -303,7 +313,41 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
self.num_experts, -1),
requires_grad=False)
return
else:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# FIXME: num_warps needs to be adjusted based on batch size;
# this only applies to batched mode
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# delete the original weights to save memory on a single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
@@ -404,3 +448,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
True, # do finalize
)[0]
return trtllm_gen_output
else:
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_precision=self.w13_precision_config,
w2_precision=self.w2_precision_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)


@@ -4,11 +4,55 @@ from typing import Callable, Optional
import torch
from vllm.utils import direct_register_custom_op
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
"""
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.numerics import InFlexData
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
if (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and not is_torch_equal_or_newer("2.8.1")):
logger.warning_once(
"Mxfp4 on hopper is running with torch < 2.8.1; this causes "
"swizzling to be disabled, which may degrade performance. "
"Please upgrade to a torch nightly build.")
value_layout, value_layout_opts = StridedLayout, dict()
scale_layout, scale_layout_opts = StridedLayout, dict()
else:
value_layout, value_layout_opts = \
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
scale_layout, scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))
if current_platform.is_cuda() and \
current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
value_layout, **value_layout_opts)
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
**scale_layout_opts)
return quant_tensor, InFlexData(), scale
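# Illustrative usage (mirrors Mxfp4MoEMethod.process_weights_after_loading in
# this commit; the variable names are placeholders):
#   w13, w13_flex, w13_scale = _swizzle_mxfp4(w13_weight, w13_weight_scale,
#                                             num_warps)
#   pc = PrecisionConfig(weight_scale=w13_scale,
#                        flex_ctx=FlexCtx(rhs_data=w13_flex))
# The swizzled tensor and precision config are then fed to matmul_ogs via the
# fused MoE path.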
def _can_support_mxfp4(use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,


@@ -11,6 +11,27 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
# Shuffle the weight along the last dimension so that
# the two halves are folded into adjacent locations.
# Example:
# input:
# [[1, 2, 3, 4, 5, 6],
# [7, 8, 9, 10, 11, 12]]
# output:
# [[1, 4, 2, 5, 3, 6],
# [7, 10, 8, 11, 9, 12]]
# This will be used together with triton swiglu kernel
shape = w.shape
N = shape[-1]
first = w[..., :N // 2]
second = w[..., N // 2:]
stacked = torch.stack((first, second), dim=-1)
w_shuffled = stacked.reshape(shape)
return w_shuffled
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,


@@ -3254,6 +3254,12 @@ def has_deep_gemm() -> bool:
return _has_module("deep_gemm")
def has_triton_kernels() -> bool:
"""Whether the optional `triton_kernels` package is available."""
return _has_module("triton_kernels")
def set_process_title(name: str,
suffix: str = "",
append: bool = False) -> None: