[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm 2025-09-17 19:43:31 -04:00 committed by GitHub
parent e6585ddb45
commit 5963b98b46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
68 changed files with 2698 additions and 2526 deletions

View File

@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
@ -140,6 +144,12 @@ def bench_run(
a_fp8_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
for _ in range(num_repeats):
fused_experts(
a,
@ -147,10 +157,7 @@ def bench_run(
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
quant_config=quant_config,
)
def run_cutlass_moe_fp4(
@ -172,25 +179,27 @@ def bench_run(
device: torch.device,
num_repeats: int,
):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
quant_config=quant_config,
)
def run_cutlass_from_graph(
@ -211,26 +220,29 @@ def bench_run(
e: int,
device: torch.device,
):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_alphas,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
quant_config=quant_config,
)
def run_triton_from_graph(
@ -246,16 +258,18 @@ def bench_run(
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
return fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
quant_config=quant_config,
)
def replay_graph(graph, num_repeats):

View File

@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
@ -96,6 +97,11 @@ def bench_run(
a_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
)
for _ in range(num_repeats):
fused_experts(
a,
@ -103,10 +109,7 @@ def bench_run(
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_config=quant_config,
)
def run_cutlass_moe(
@ -125,6 +128,12 @@ def bench_run(
per_act_token: bool,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)
for _ in range(num_repeats):
cutlass_moe_fp8(
a,
@ -132,14 +141,11 @@ def bench_run(
w2,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
quant_config=quant_config,
)
def run_cutlass_from_graph(
@ -156,6 +162,12 @@ def bench_run(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
@ -165,14 +177,11 @@ def bench_run(
w2_q,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
quant_config=quant_config,
)
def run_triton_from_graph(
@ -185,6 +194,11 @@ def bench_run(
w2_scale: torch.Tensor,
a_scale: torch.Tensor,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
@ -194,10 +208,7 @@ def bench_run(
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_config=quant_config,
)
def replay_graph(graph, num_repeats):

View File

@ -14,6 +14,10 @@ import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
@ -134,43 +138,36 @@ def benchmark_config(
def run():
from vllm.model_executor.layers.fused_moe import override_config
if use_fp8_w8a8:
quant_dtype = torch.float8_e4m3fn
elif use_int8_w8a16:
quant_dtype = torch.int8
else:
quant_dtype = None
quant_config = FusedMoEQuantConfig.make(
quant_dtype=quant_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)
with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, False
)
return fused_experts(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
allow_deep_gemm=True,
)
else:
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
return fused_experts(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
quant_config=quant_config,
allow_deep_gemm=use_deep_gemm,
)
# JIT compilation & warmup
run()
@ -414,7 +411,7 @@ class BenchmarkWorker:
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
@ -547,7 +544,7 @@ def save_configs(
block_quant_shape: list[int],
save_dir: str,
) -> None:
dtype_str = get_config_dtype_str(
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)

View File

@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts,
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo
@ -40,7 +40,7 @@ class Config:
E: int
topks: Union[list[int], int]
dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig]
quant_config: Optional[TestMoEQuantConfig]
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
@ -52,7 +52,7 @@ class Config:
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
self.quant_config = TestMoEQuantConfig(None, False, False, None)
def describe(self) -> str:
s = ""
@ -275,21 +275,19 @@ class WeightTensors:
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device())
device = torch.cuda.current_device()
self.w1 = self.w1.to(device=device)
self.w2 = self.w2.to(device=device)
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to(
device=torch.cuda.current_device())
self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device())
if self.w1_scale is not None:
self.w1_scale = self.w1_scale.to(device=device)
if self.w2_scale is not None:
self.w2_scale = self.w2_scale.to(device=device)
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
self.w1_gs = self.w1_gs.to(device=device)
if self.w2_gs is not None:
self.w2_gs = self.w2_gs.to(device=device)
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
@ -297,20 +295,12 @@ class WeightTensors:
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
w1_scale, w2_scale = (None, None)
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :]
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
w1_scale = self.w1_scale[
s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[
s:e, :, :] if self.w2_scale is not None else None
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
@ -323,7 +313,8 @@ class WeightTensors:
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
per_out_ch_quant=config.
is_per_act_token_quant, # or config.is_per_out_ch_quant
)
return WeightTensors(w1=w1,
w2=w2,
@ -342,8 +333,6 @@ class RankTensors:
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor]
quant_config: Optional[FusedMoEQuantConfig]
def describe(self):
s = ""
s += "== Rank Tensors: \n"
@ -426,7 +415,6 @@ class RankTensors:
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_config=config.quant_config,
)
@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
and config.supports_apply_weight_on_input())
def _make_gscale(
    num_experts: int,
    device: Optional[Union[torch.device, str, int]] = None,
) -> torch.Tensor:
    """Return a 1-D float32 tensor of ones with one entry per expert.

    Args:
        num_experts: Length of the returned tensor.
        device: Device to allocate the tensor on. When omitted, the current
            CUDA device (``torch.cuda.current_device()``) is used, which is
            the behaviour existing callers rely on.

    Returns:
        A tensor of shape ``(num_experts,)`` filled with ``1.0``.
    """
    if device is None:
        # Resolve lazily so that an explicit device never has to touch CUDA.
        device = torch.cuda.current_device()
    return torch.ones((num_experts, ), device=device, dtype=torch.float32)
def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
weights: WeightTensors,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
@ -548,20 +542,20 @@ def make_modular_kernel(
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
quant_config=config.quant_config,
max_num_tokens=next_power_of_2(config.M),
)
# make modular kernel
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe)
config.all2all_backend(), moe,
quant_config)
fused_experts = make_fused_experts(
config.fused_experts_type,
moe,
quant_config,
prepare_finalize.num_dispatchers(),
weights.w1_gs,
weights.w2_gs,
config.N,
)
modular_kernel = mk.FusedMoEModularKernel(
@ -583,12 +577,38 @@ def run_modular_kernel(
# weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config, weights)
if config.quant_dtype == "nvfp4":
gscale = _make_gscale(config.num_local_experts)
else:
gscale = None
quant_config = FusedMoEQuantConfig.make(
config.quant_dtype,
w1_scale=rank_weights.w1_scale,
w2_scale=rank_weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
g1_alphas=(1 / rank_weights.w1_gs)
if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs)
if rank_weights.w2_gs is not None else None,
a1_gscale=gscale,
a2_gscale=gscale,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_act_token_quant,
per_out_ch_quant=config.is_per_out_ch_quant,
)
mk = make_modular_kernel(config, vllm_config, quant_config)
# impls might update the tensor in place
hidden_states = rank_tensors.hidden_states.clone()
topk_ids = rank_tensors.topk_ids.to(
mk.prepare_finalize.topk_indices_dtype())
mk_kwargs = {
"hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place
hidden_states,
"w1":
rank_weights.w1,
"w2":
@ -596,15 +616,9 @@ def run_modular_kernel(
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
topk_ids,
"expert_map":
rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":

View File

@ -10,7 +10,8 @@ import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
quant_config_dict = config_dict['quant_config']
del config_dict['quant_config']
if quant_config_dict is None:
quant_config = FusedMoEQuantConfig(None)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict

View File

@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@dataclass
class TestMoEQuantConfig:
    """Quantization scheme describing one modular-kernel test configuration.

    Holds only the scheme (dtype and granularity flags); it carries no scale
    tensors.
    """

    # The class name starts with "Test", so pytest would try to collect it
    # from any test module that imports it and warn because it has an
    # __init__. Opt out of collection explicitly. This is not a dataclass
    # field because it has no annotation.
    __test__ = False

    # Quantized dtype: a torch dtype, a string tag such as "nvfp4", or None
    # for the unquantized case.
    quant_dtype: Union[torch.dtype, str, None]
    # True for per-channel / per-column weight scales.
    per_out_ch_quant: bool
    # True for per-token activation scales.
    per_act_token_quant: bool
    # Block size for block-quantized weights (e.g. [128, 128]), else None.
    block_shape: Optional[list[int]]
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
nvfp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_experts(
FlashInferExperts,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
@ -306,39 +314,39 @@ if cutlass_fp4_supported():
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
None,
# per-channel / per-column weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
# per-channel / per-column weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
# per-tensor weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
# per-tensor weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
# block-quantized weights and 128 block per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
# TODO (varun): Should we test the following combinations?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(
    num_experts: int,
    device: Optional[Union[torch.device, str, int]] = None,
) -> torch.Tensor:
    """Return a 1-D float32 tensor of ones with one entry per expert.

    Args:
        num_experts: Length of the returned tensor.
        device: Device to allocate the tensor on. When omitted, the current
            CUDA device (``torch.cuda.current_device()``) is used, which is
            the behaviour existing callers rely on.

    Returns:
        A tensor of shape ``(num_experts,)`` filled with ``1.0``.
    """
    if device is None:
        # Resolve lazily so that an explicit device never has to touch CUDA.
        device = torch.cuda.current_device()
    return torch.ones((num_experts, ), device=device, dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
use_dp=moe.moe_parallel_config.dp_size > 1)
else:
return MoEPrepareAndFinalizeNoEP()
@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
return t[s:e]
def make_cutlass_strides(
    e: int,
    n: int,
    k: int,
    device: Union[torch.device, str] = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the four per-expert stride tensors passed to the CUTLASS experts.

    Each returned tensor is 1-D, has ``e`` entries, dtype ``int64``, and is
    filled with a single constant value.

    Args:
        e: Number of entries in each stride tensor (one per expert).
        n: Value used for ``ab_strides2``; ``c_strides1`` uses ``2 * n``.
        k: Value used for ``ab_strides1`` and ``c_strides2``.
        device: Device to allocate on. Defaults to ``"cuda"``, as before.

    Returns:
        ``(ab_strides1, ab_strides2, c_strides1, c_strides2)`` filled with
        ``k``, ``n``, ``2 * n`` and ``k`` respectively. The four tensors are
        distinct allocations.
    """

    def _filled(value: int) -> torch.Tensor:
        # One freshly allocated tensor per call, so no storage is shared.
        return torch.full((e, ), value, device=device, dtype=torch.int64)

    ab_strides1 = _filled(k)
    ab_strides2 = _filled(n)
    c_strides1 = _filled(2 * n)
    c_strides2 = _filled(k)
    return ab_strides1, ab_strides2, c_strides1, c_strides2
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
"quant_config": quant_config,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
@ -422,8 +429,8 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
# Log the quant config actually used; the f-prefix is required for interpolation.
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
@ -437,62 +444,50 @@ def make_fused_experts(
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
}
"out_dtype": moe.in_dtype,
} | quant_kwargs
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
} | quant_kwargs
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
return experts

View File

@ -6,6 +6,8 @@ import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
rank=0,
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
per_act_token_quant=False,
block_shape=BLOCK_SIZE,
)
# triton (reference)
triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=True,
per_act_token_quant=False,
block_shape=BLOCK_SIZE,
quant_config=quant_config,
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
w1_scale=w1_s,
w2_scale=w2_s,
global_num_experts=E,
)
@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
deepgemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
block_shape=BLOCK_SIZE,
per_act_token_quant=False,
quant_config=quant_config,
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
w1_scale=w1_s,
w2_scale=w2_s,
global_num_experts=E,
)

View File

@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
out_shape = (num_experts, max_tokens_per_expert, N)
@ -250,7 +250,7 @@ def test_fused_moe_batched_experts(
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
if input_scales and quant_dtype is not None:

View File

@ -4,7 +4,7 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
w1, w2, quant_config = make_test_quant_config(
E,
N,
K,
dtype,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size,
)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(quant_config)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a,
w1,
w2,
w1_s,
w2_s,
quant_config.w1_scale,
quant_config.w2_scale,
topk_weights,
topk_ids,
block_size,
)
out = fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
m_out = m_fused_moe(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol = 0.035 if M < 40000 else 0.039
# 0.039 only needed for M >= 8192
tol = 0.035 if M < 8192 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_out_ch_quant=False,
block_shape=block_size,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and

View File

@ -4,12 +4,12 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
@ -50,7 +50,7 @@ MNK_FACTORS = [
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 7168),
(2048, 4096, 4096),
]
E = [8, 24]
@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.int8,
per_act_token_quant=False,
block_shape=block_size)
w1, w2, quant_config = make_test_quant_config(
E,
N,
K,
dtype,
quant_dtype=torch.int8,
per_act_token_quant=False,
block_shape=block_size,
)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
quant_config.w2_scale, score, topk,
block_size)
# Check results

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses
from math import prod
from typing import Optional
@ -9,6 +10,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def slice_experts():
slice_params = [
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
"c_strides2", "w1_scale", "w2_scale"
"c_strides2"
]
full_tensors = {
k: v
@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
if k in slice_params and k in cutlass_moe_kwargs
}
quant_config = cutlass_moe_kwargs["quant_config"]
for i in range(0, num_experts, num_local_experts):
s, e = i, i + num_local_experts
@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
for k, t in full_tensors.items():
cutlass_moe_kwargs[k] = t[s:e]
new_quant_config = copy.deepcopy(quant_config)
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
cutlass_moe_kwargs["quant_config"] = new_quant_config
yield cutlass_moe_kwargs
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
]
])
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=moe_tensors.w1_scale,
w2_scale=moe_tensors.w2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
# Set to moe_tensors.a_scale iff static scales + per tensor.
# This is not currently being tested.
a1_scale=None,
)
kwargs = {
'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
'quant_config': quant_config,
}
num_experts = moe_tensors.w1.size(0)
@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
number_local_experts = e // ep_size
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
number_local_experts)
per_out_ch, number_local_experts)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
per_act_token, per_out_ch)
torch.cuda.synchronize()
graph.replay()

View File

@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
Return weights w1q, w2q, w1_scale, w2_scale
"""
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
_) = make_test_weights(e, n, k, torch.bfloat16,
_) = make_test_weights(e,
n,
k,
torch.bfloat16,
torch.float8_e4m3fn,
block_size)
block_shape=block_size)
return w1q, w2q, w1_scale, w2_scale
@ -130,10 +135,11 @@ class TestTensors:
config=config)
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
max_tokens_per_rank: int, dp_size: int,
hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:
def make_ll_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size,
block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:
def make_ht_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype,
block_shape=test_config.block_size)
fused_experts = DeepGemmExperts()
fused_experts = DeepGemmExperts(quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int,
test_tensors: TestTensors) -> FusedMoEModularKernel:
def make_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config)
test_config=test_config,
quant_config=quant_config)
else:
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
q_dtype, test_config)
mk = make_ht_modular_kernel(pg,
pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config)
return mk
@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
# Low-Latency kernels can't dispatch scales.
a1_scale=(None if test_config.low_latency else
test_tensors.rank_token_scales),
block_shape=test_config.block_size,
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg=pg,
pgi=pgi,
dp_size=dp_size,
num_local_experts=num_local_experts,
test_tensors=test_tensors)
# Low-Latency kernels can't dispatch scales.
a1_scale = (None
if test_config.low_latency else test_tensors.rank_token_scales)
test_tensors=test_tensors,
quant_config=quant_config)
out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=a1_scale,
a2_scale=None,
apply_router_weight_on_input=False)
return out
@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, block_shape: list[int]):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
)
return fused_experts(
hidden_states=a,
w1=w1,
@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
quant_config=quant_config,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm=False)

View File

@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@ -129,11 +130,9 @@ def make_modular_kernel(
num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None
ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None
@ -159,24 +158,14 @@ def make_modular_kernel(
num_dispatchers = pgi.world_size // dp_size
if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
assert not quant_config.per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
num_dispatchers=num_dispatchers,
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=False,
quant_config=quant_config,
)
else:
fused_experts = TritonExperts(
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
fused_experts = TritonExperts(quant_config=quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
@ -217,11 +206,6 @@ def deep_ep_moe_impl(
if is_quantized:
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0)
@ -236,6 +220,19 @@ def deep_ep_moe_impl(
rank_token_scales_chunk = rank_token_scales_chunk[
chunk_start:chunk_end]
quant_config = FusedMoEQuantConfig.make(
q_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token_quant,
a1_scale=rank_token_scales_chunk,
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
out = mk.forward(hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
@ -245,12 +242,6 @@ def deep_ep_moe_impl(
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=rank_token_scales_chunk,
a2_scale=None,
apply_router_weight_on_input=False)
if not skip_result_store:
@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@requires_deep_ep
def test_deep_ep_moe(
dtype: torch.dtype,
mnk: tuple[int, int, int],
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
@ -424,7 +417,6 @@ def test_deep_ep_moe(
):
low_latency_mode = False
use_fp8_dispatch = False
m, n, k = mnk
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool):
def test_low_latency_deep_ep_moe(
dtype: torch.dtype,
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool,
):
low_latency_mode = True
m, n, k = mnk
if (low_latency_mode
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):

View File

@ -11,6 +11,8 @@ import math
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
)
# triton reference
out_triton = fused_experts(
hidden_states=tokens_bf16,
@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
quant_config=quant_config,
allow_deep_gemm=False,
)
@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
quant_config=quant_config,
allow_deep_gemm=True,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
# Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path.
# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [
(1024, 768, 128),
(1024, 768, 512),
@ -144,15 +144,15 @@ TOPKS = [2, 6]
NUM_EXPERTS = [32]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_DEEP_GEMM", "1")
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe")
@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
_spy_deep_gemm_moe_fp8)
m, n, k = mnk
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")

View File

@ -6,6 +6,8 @@ import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts(
td.hidden_states,
td.w13_quantized,
@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
quant_config=quant_config,
)
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts(
td.hidden_states,
td.w13_quantized,
@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
quant_config=quant_config,
)
td.layer.dp_size = 1

View File

@ -3,7 +3,7 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
@ -41,7 +41,6 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
quant_blocksize = 16
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
w1_q, w2_q, quant_config = make_test_quant_config(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
a1_gscale=a1_gs,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
)
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

View File

@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
quant_config = FusedMoEQuantConfig.make(
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2,
)
out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri,
w1=w1_tri,
@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
gating_output=exp_data_tri,
topk=topk,
renormalize=True,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2,
quant_config=quant_config,
)
out_triton_monolithic = out_triton_monolithic[..., :K]
@ -336,6 +341,13 @@ def batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
w1_precision=w1_precision,
w2_precision=w2_precision,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(
max_num_tokens,
@ -344,19 +356,12 @@ def batched_moe(
rank=0,
),
BatchedOAITritonExperts(
None,
max_num_tokens=max_num_tokens,
num_dispatchers=1,
w1_precision=w1_precision,
w2_precision=w2_precision,
quant_config=quant_config,
),
)
extra_expert_args = {
"w1_bias": w1_bias,
"w2_bias": w2_bias,
}
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
return fused_experts(
@ -365,7 +370,6 @@ def batched_moe(
w2,
topk_weight,
topk_ids,
extra_expert_args=extra_expert_args,
)

View File

@ -12,7 +12,6 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config)
@ -55,7 +55,7 @@ def rank_worker(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
cpu_group,
config: Config,
base_config: Config,
weights: WeightTensors,
verbose: bool,
):
@ -63,42 +63,44 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
if base_config.fused_moe_chunk_size is not None:
assert (
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# get weights to this device
weights.to_current_device()
Ms = config.Ms
Ms = base_config.Ms
assert isinstance(Ms, list)
TOPKs = config.topks
TOPKs = base_config.topks
assert isinstance(TOPKs, list)
exceptions = []
count = 0
for m, topk in product(Ms, TOPKs):
# override m and topk
config = copy.deepcopy(base_config)
config.Ms = m
config.topks = topk
try:
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
count = count + 1
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
rank_tensors = RankTensors.make(config, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
ref_out = reference_moe_impl(config, weights, rank_tensors)
if config.quant_dtype == "nvfp4":
atol = 1e-1
rtol = 1e-1
atol = 1e-1 if config.K < 4096 else 2e-1
rtol = 1e-1 if config.K < 4096 else 2e-1
else:
atol = 3e-2
rtol = 3e-2
@ -132,7 +134,7 @@ Ms = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool:
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu(
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):

View File

@ -15,11 +15,14 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@ -187,14 +190,9 @@ def test_fused_moe(
#
# Setup test functions
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=None)
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
def m_fused_moe(
a: torch.Tensor,
@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else:
e_map = None
if weight_bits == 4:
quant_config_builder = int4_w4a16_moe_quant_config
else:
assert weight_bits == 8
quant_config_builder = int8_w8a16_moe_quant_config
quant_config = quant_config_builder(w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a,
w1_qweight,
@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
score,
topk,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=e,
expert_map=e_map,
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
quant_config=quant_config)
torch_output = torch_moe(a,
w1_ref,
w2_ref,

View File

@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
per_out_ch_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
assert w1_blockscale is not None
assert w2_blockscale is not None
quant_config = nvfp4_moe_quant_config(
g1_alphas=(1 / w1_gs),
g2_alphas=(1 / w2_gs),
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
)
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
g2_alphas=(1 / w2_gs),
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=quant_config,
m=m,
n=n,
k=k,

View File

@ -9,6 +9,8 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@ -143,10 +145,16 @@ def pplx_cutlass_moe(
device="cuda",
dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch,
ab_strides1, ab_strides2, c_strides1,
c_strides2)
experts = CutlassBatchedExpertsFp8(
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides2, c_strides1, c_strides2,
fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
@ -167,10 +175,7 @@ def pplx_cutlass_moe(
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
)
torch.cuda.synchronize()

View File

@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
]
PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
@ -360,18 +360,18 @@ def pplx_prepare_finalize(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk,
a1_scale,
a2_scale,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
FusedMoEQuantConfig(
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant,
False,
block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=block_shape,
a1_scale=a1_scale,
a2_scale=a2_scale,
),
)
@ -540,20 +540,6 @@ def pplx_moe(
topk_ids = topk_ids.to(dtype=torch.uint32)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
@ -567,6 +553,28 @@ def pplx_moe(
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
@ -585,10 +593,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
@ -605,10 +609,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@ -820,7 +820,7 @@ def test_pplx_moe_slow(
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
args["w1"] = w1
args["w2"] = w2

View File

@ -7,10 +7,12 @@ import itertools
import pytest
import torch
from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0):
@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score,
topk,
renormalize=False,
use_fp8_w8a8=True, # using fp8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
quant_config=fp8_w8a8_moe_quant_config(
per_act_token_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
),
)
# Check results

View File

@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@ -34,18 +35,22 @@ def triton_moe(
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
quant_config=quant_config)
def batched_moe(
@ -64,6 +69,16 @@ def batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
@ -72,21 +87,11 @@ def batched_moe(
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def naive_batched_moe(
@ -105,6 +110,16 @@ def naive_batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
@ -113,21 +128,11 @@ def naive_batched_moe(
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
@ -216,7 +221,7 @@ def make_test_weight(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
@ -228,7 +233,7 @@ def make_test_weight(
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
@ -258,16 +263,16 @@ def make_test_weights(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
)
@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Union[torch.dtype, str, None] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """
    Build test weights plus a FusedMoEQuantConfig that describes them.

    The weights come from make_test_weights; per_act_token_quant is forwarded
    to it as per_out_ch_quant, so the same flag drives both the weight
    quantization and the activation setting stored in the config.

    Returns (w1, w2, quant_config), where w1/w2 are the second entry of each
    tuple produced by make_test_weights.
    """
    w1_parts, w2_parts = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )
    _, w1, w1_s, w1_gs = w1_parts
    _, w2, w2_s, w2_gs = w2_parts

    # Hacky/trivial scales for nvfp4: all-ones global scales, one per expert.
    # Every other quant_dtype leaves the activation scales unset.
    a1_gscale: Optional[torch.Tensor] = None
    a2_gscale: Optional[torch.Tensor] = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)

    # TODO: make sure this is handled properly
    g1_alphas = None if w1_gs is None else (1 / w1_gs)
    g2_alphas = None if w2_gs is None else (1 / w2_gs)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        # The plain activation scales reuse the global scales (None unless
        # quant_dtype is "nvfp4").
        a1_scale=a1_gscale,
        a2_scale=a2_gscale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
    )
    return w1, w2, quant_config
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: Optional[FusedMoEQuantConfig] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Test helper: route with fused_topk, then run fused_experts.

    score is converted with .float() before routing. fused_topk returns three
    values; only the first two (weights and ids) are passed on.
    """
    routing_weights, routing_ids, _ = fused_topk(
        hidden_states,
        score.float(),
        topk,
        renormalize,
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        routing_weights,
        routing_ids,
        quant_config=quant_config,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
# CustomOp?
class BaselineMM(torch.nn.Module):

View File

@ -8,7 +8,8 @@ import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8)
from vllm.platforms import current_platform
@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
topk_ids):
"""This function performs fused moe with per-column int8 quantization
using native torch."""
@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
out = fused_moe(
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
topk_weights, topk_ids)
quant_config = FusedMoEQuantConfig.make(
torch.int8,
per_act_token_quant=True,
block_shape=None,
w1_scale=w1_s,
w2_scale=w2_s,
)
out = fused_experts(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True, # Using int8-w8a8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
topk_weights,
topk_ids,
quant_config=quant_config,
)
# Check results

View File

@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
@ -36,6 +37,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"activation_without_mul",
"override_config",
"get_config",
]
@ -43,7 +45,6 @@ __all__ = [
if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
@ -56,13 +57,12 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
TritonExperts, fused_experts, fused_topk, get_config_file_name,
grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
__all__ += [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",

View File

@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
def __init__(self,
max_num_tokens: int,
num_dispatchers: int,
block_shape: list[int],
per_act_token_quant=False):
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
block_shape: Block quantization block shape.
per_act_token_quant: Per activation token quantization flag.
quant_config: Quantization configuration
"""
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
super().__init__(quant_config)
assert self.block_shape == deep_gemm_block_shape()
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
workspace1, expert_num_tokens, expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
workspace1, expert_num_tokens)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
expert_num_tokens, expected_m)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
output, expert_num_tokens, expected_m)

View File

@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
allow_deep_gemm: bool = False):
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
))
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(quant_config)
self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape,
quant_config=self.quant_config,
)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8
and self.block_shape
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
block_shape=self.block_shape, # type: ignore[arg-type]
quant_config=self.quant_config,
) if self.allow_deep_gemm else None
assert (self.batched_deep_gemm_experts is not None
@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
activation, global_num_experts, expert_map, a1q_scale,
workspace13, workspace2, expert_tokens_meta,
apply_router_weight_on_input)

View File

@ -1,103 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import cdiv
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.utils import cdiv, has_triton_kernels
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
# Imported for static type checking only; the name is not bound at runtime.
# has_triton_kernels must be *called*: a bare function object is always
# truthy, which turned the second half of this condition into a no-op.
if TYPE_CHECKING and has_triton_kernels():
    from triton_kernels.matmul_ogs import PrecisionConfig
logger = init_logger(__name__)
def _get_quant_config_quantization_args(
quant_config: Optional[QuantizationConfig],
prop_name: str,
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(prop_name)
else:
return None
def get_quant_config_input_quant(
    quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
    """
    Return the "input_activations" entry of the config's "Linear" scheme,
    or None when the lookup helper finds nothing.
    """
    return _get_quant_config_quantization_args(
        quant_config=quant_config,
        prop_name="input_activations",
    )
def get_quant_config_weight_quant(
    quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
    """
    Return the "weights" entry of the config's "Linear" scheme, or None when
    the lookup helper finds nothing.
    """
    return _get_quant_config_quantization_args(
        quant_config=quant_config,
        prop_name="weights",
    )
def get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
def _get_config_dtype_str(
dtype: torch.dtype,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
) -> Optional[str]:
"""
Return a string used to construct the filename that contains the
tuning info for a particular quantization scheme. See
try_get_optimal_moe_config in fused_moe.py.
"""
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_mxfp4_w4a4:
return "mxfp4"
return "mxfp4_w4a4"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return "float32"
return None
def _quant_flags_to_group_shape(
    quant_dtype: Union[torch.dtype, str, None],
    per_act_token_quant: bool,
    per_out_ch_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
    """
    Convert MoE quantization flags into more generic GroupShapes.

    Returns (activation shape, weight shape):
    - block_shape given: both are GroupShape(block_shape[0], block_shape[1]);
      the per-token / per-channel flags must be off.
    - otherwise the activation shape is PER_TOKEN if per_act_token_quant,
      else PER_TENSOR when quant_dtype is set, else None; the weight shape
      is PER_TOKEN if per_out_ch_quant, else None.
    """
    if block_shape is not None:
        assert not per_act_token_quant
        assert not per_out_ch_quant
        rows, cols = block_shape[0], block_shape[1]
        # TODO(bnell): this is not quite right for activations since first
        # dim should be 1.
        return (GroupShape(row=rows, col=cols),
                GroupShape(row=rows, col=cols))

    a_shape: Optional[GroupShape]
    if per_act_token_quant:
        a_shape = GroupShape.PER_TOKEN
    elif quant_dtype is None:
        a_shape = None
    else:
        a_shape = GroupShape.PER_TENSOR

    w_shape: Optional[GroupShape] = (GroupShape.PER_TOKEN
                                     if per_out_ch_quant else None)
    return a_shape, w_shape
@dataclass
class FusedMoEQuantDesc:
    """
    A quantization descriptor for fused MoE ops. This class can describe
    either activations or weights.

    Every field defaults to None. NOTE(review): instances are also built
    positionally elsewhere in this file, so the field order below is part of
    the interface — do not reorder.
    """

    # The quantized type of this parameter. None means unquantized or
    # already quantized.
    # TODO (bnell): use scalar_type instead of Union.
    dtype: Union[torch.dtype, str, None] = None

    # A field that describes the quantization group shape, from quant_utils.py.
    #  * (-1, -1)   for per-tensor quantization
    #  * (1, -1)    for per-row quantization
    #  * (-1, 1)    for per-column quantization
    #  * (128, 128) for 128x128 deepseek style block quantization
    #  * (1, 128)   for deepseek style activation quantization
    #               (i.e. per-token-per-group)
    shape: Optional[GroupShape] = None

    # Quantization scales. Holds either a tensor or a PrecisionConfig.
    # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
    scale: Union[torch.Tensor, "PrecisionConfig", None] = None

    # Quantization alphas or gscales, used for nvfp4 types.
    # TODO(bnell): put some of these in subclasses
    alpha_or_gscale: Optional[torch.Tensor] = None

    # Zero points for int4/int8 types
    zp: Optional[torch.Tensor] = None

    # Biases for GPT triton MoE
    bias: Optional[torch.Tensor] = None
# TODO(bnell): have subclasses for specific moe methods?
# e.g. for specific arguments bias, precision, etc.
@dataclass
class FusedMoEQuantConfig:
# The post quantization activation type.
# TODO (bnell): use scalar_type instead of Union.
quant_dtype: Union[torch.dtype, str, None] = None
per_act_token_quant: bool = False
per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None
"""
The FusedMoEQuantConfig contains all the quantization parameters for
a single FusedMoEMethodBase operation. It consists of four
FusedMoEQuantDescs, one for each activation and set of weights.
# TODO: add col major flag?
# add detailed quant info for input, intermediates, weights, etc?
Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
method to construct a FusedMoEQuantConfig for use with that class.
FusedMoEQuant configs are only used for modular kernels, fused_experts
(from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
triton_kernel_moe_forward. Other MoE methods can ignore the
FusedMoEQuantConfig (for now) and hardcode it to None.
There are currently some restrictions on what can be expressed:
- Most MoE ops only support similar quantization strategies for
each parameter, e.g. both weights must have the same GroupShape
and both activations must share the same GroupShape. One exception to
this is the cutlass moe which allows per channel quantization on the
outputs. Note: these restrictions are not always rigorously checked.
- Not all fused MoE functions support all the parameters, e.g. zero points,
global scales, alphas and biases are not universally supported.
- Fully general GroupShapes are not allowed. Activations only support
per token, per tensor or K-blocked.
- Weights are not required to have a GroupShape since they have already
been quantized.
Other notes:
- PrecisionConfigs are specific to GPT OSS Triton.
- As a follow up it would probably make sense to subclass FusedMoEQuantDesc
or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
so that only the required quantization parameters are used/stored.
"""
# TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking
_a1: FusedMoEQuantDesc
_a2: FusedMoEQuantDesc
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
def __post_init__(self):
assert (not self.per_act_token_quant
or self.block_shape is None), "illegal quantization"
#
# Convenience accessors for various properties.
#
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
return self._a1.dtype
@property
def is_quantized(self) -> bool:
return self.quant_dtype is not None
@property
def is_per_act_token(self) -> bool:
return self.per_act_token_quant
return self._a1.shape == GroupShape.PER_TOKEN
@property
def per_act_token_quant(self) -> bool:
return self._a1.shape == GroupShape.PER_TOKEN
@property
def per_out_ch_quant(self) -> bool:
return self._w1.shape == GroupShape.PER_TOKEN
@property
def is_per_tensor(self) -> bool:
return self._a1.shape == GroupShape.PER_TENSOR
@property
def block_shape(self) -> Optional[list[int]]:
if (self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR
and self._a1.shape != GroupShape.PER_TOKEN):
return [self._a1.shape.row, self._a1.shape.col]
else:
return None
@property
def is_block_quantized(self) -> bool:
return self.block_shape is not None
@property
def is_per_tensor(self) -> bool:
return not self.per_act_token_quant and self.block_shape is None
def a1_scale(self) -> Optional[torch.Tensor]:
assert self._a1.scale is None or isinstance(self._a1.scale,
torch.Tensor)
return self._a1.scale
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
return self._a1.alpha_or_gscale
@property
def a2_scale(self) -> Optional[torch.Tensor]:
assert self._a2.scale is None or isinstance(self._a2.scale,
torch.Tensor)
return self._a2.scale
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
return self._a2.alpha_or_gscale
@property
def w1_scale(self) -> Optional[torch.Tensor]:
assert self._w1.scale is None or isinstance(self._w1.scale,
torch.Tensor)
return self._w1.scale
@property
def w1_zp(self) -> Optional[torch.Tensor]:
return self._w1.zp
@property
def w1_bias(self) -> Optional[torch.Tensor]:
return self._w1.bias
@property
def w1_precision(self) -> Optional["PrecisionConfig"]:
    """
    Return the w1 scale slot when it holds a PrecisionConfig (or None).

    PrecisionConfig comes from an optional dependency and may not be bound
    at runtime, so an isinstance check against it can raise NameError as
    soon as a scale is present. The scale slot only admits a torch.Tensor,
    a PrecisionConfig or None, so ruling out the tensor case gives the same
    guarantee without naming the class.

    Raises AssertionError if the slot holds a torch.Tensor.
    """
    scale = self._w1.scale
    assert not isinstance(scale, torch.Tensor)
    return scale
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
return self._w1.alpha_or_gscale
@property
def w2_scale(self) -> Optional[torch.Tensor]:
assert self._w2.scale is None or isinstance(self._w2.scale,
torch.Tensor)
return self._w2.scale
@property
def w2_zp(self) -> Optional[torch.Tensor]:
return self._w2.zp
@property
def w2_bias(self) -> Optional[torch.Tensor]:
return self._w2.bias
@property
def w2_precision(self) -> Optional["PrecisionConfig"]:
    """
    Return the w2 scale slot when it holds a PrecisionConfig (or None).

    PrecisionConfig comes from an optional dependency and may not be bound
    at runtime, so an isinstance check against it can raise NameError as
    soon as a scale is present. The scale slot only admits a torch.Tensor,
    a PrecisionConfig or None, so ruling out the tensor case gives the same
    guarantee without naming the class.

    Raises AssertionError if the slot holds a torch.Tensor.
    """
    scale = self._w2.scale
    assert not isinstance(scale, torch.Tensor)
    return scale
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
return self._w2.alpha_or_gscale
@property
def use_fp8_w8a8(self) -> bool:
return self.quant_dtype == torch.float8_e4m3fn
@property
def use_int8_w8a8(self) -> bool:
return self.quant_dtype == torch.int8
@property
def use_int8_w8a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == torch.int8)
@property
def use_int4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "int4")
@property
def use_mxfp4_w4a4(self) -> bool:
return self.quant_dtype == "mxfp4"
@property
def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4"
def config_name(self, dtype: torch.dtype) -> Optional[str]:
    """
    Return a string used to construct the filename that contains the
    tuning info for a particular quantization scheme. See
    try_get_optimal_moe_config in fused_moe.py.

    The name is derived from this config's fp8_w8a8 / int8_w8a16 /
    int4_w4a16 / mxfp4_w4a4 flags and the given dtype.
    """
    scheme_flags = {
        "use_fp8_w8a8": self.use_fp8_w8a8,
        "use_int8_w8a16": self.use_int8_w8a16,
        "use_int4_w4a16": self.use_int4_w4a16,
        "use_mxfp4_w4a4": self.use_mxfp4_w4a4,
    }
    return _get_config_dtype_str(dtype=dtype, **scheme_flags)
def scale_shape(
self,
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int]]:
"""
Construct the proper activation scale shape for this
config.
"""
if self.is_quantized:
if self.is_block_quantized:
assert self.block_shape is not None
@ -117,6 +336,10 @@ class FusedMoEQuantConfig:
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int, int]]:
"""
Construct the proper activation batched scale shape for this
config, e.g. (num experts, *scale_shape).
"""
if self.is_quantized:
scale_shape = self.scale_shape(max_tokens, hidden_dim)
assert scale_shape is not None
@ -126,38 +349,218 @@ class FusedMoEQuantConfig:
@staticmethod
def make(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
quant_dtype: Union[torch.dtype, str, None] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
g1_alphas: Optional[torch.Tensor] = None,
g2_alphas: Optional[torch.Tensor] = None,
a1_gscale: Optional[torch.Tensor] = None,
a2_gscale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
) -> "FusedMoEQuantConfig":
assert sum([
int(flag) for flag in [
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_mxfp4_w4a4,
]
]) <= 1, "Quantization flags are mutually exclusive."
"""
General builder function for a FusedMoEQuantConfig.
- quant_dtype: Optional quantization type. None if activations are
unquantized or quantized prior to calling. Note: "nvfp4" and
"mxfp4" are the only valid string values for quant_dtype.
- per_act_token_quant: Activations have per token quantization.
- per_out_ch_quant: Outputs have per channel quantization. (only
for cutlass).
- block_shape: Optional block size for block-wise quantization.
Incompatible with per_act_token and per_out_ch quant.
- w1_scale: Optional scale to be used for w1.
- w2_scale: Optional scale to be used for w2.
- a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
- w1_bias: Optional biases for w1 (GPT OSS Triton).
- w2_bias: Optional biases for w2 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
"""
assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4"
or quant_dtype == "mxfp4")
a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape)
quant_config = FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
_w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas,
w1_zp, w1_bias),
_w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas,
w2_zp, w2_bias),
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape
return quant_config
quant_dtype = get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
)
return FusedMoEQuantConfig(
quant_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
def fp8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for fp8 activations and fp8 weights (w8a8).

    All arguments are forwarded unchanged to FusedMoEQuantConfig.make, with
    torch.float8_e4m3fn supplied as the quantization dtype.
    """
    fp8_dtype = torch.float8_e4m3fn
    config = FusedMoEQuantConfig.make(
        fp8_dtype,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    return config
def int8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for int8 activations and int8 weights (w8a8).

    Per-output-channel quantization is always disabled and no block shape
    is used; everything else is forwarded to FusedMoEQuantConfig.make with
    torch.int8 as the quantization dtype.
    """
    int8_dtype = torch.int8
    config = FusedMoEQuantConfig.make(
        int8_dtype,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=False,
        block_shape=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    return config
def mxfp4_w4a4_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for mxfp4 activations and mxfp4 weights (w4a4).

    Per-token and per-output-channel quantization are always disabled; the
    scales, biases and block shape are forwarded to
    FusedMoEQuantConfig.make with "mxfp4" as the quantization dtype.
    """
    config = FusedMoEQuantConfig.make(
        "mxfp4",
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    return config
def nvfp4_moe_quant_config(
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for nvfp4 activations and nvfp4 weights.

    Arguments:
    - g1_alphas: Global quantization scales for w1.
    - g2_alphas: Global quantization scales for w2.
    - a1_gscale: Global quantization scales for a1.
    - a2_gscale: Global quantization scales for a2.
    - w1_scale: Scale to be used for w1.
    - w2_scale: Scale to be used for w2.

    Per-token and per-output-channel quantization are always disabled and
    no block shape is passed.
    """
    return FusedMoEQuantConfig.make(
        "nvfp4",
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
    )
def int4_w4a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for 16-bit float activations and int4 weights.

    The activation descriptors carry only the group shape; the weight
    descriptors carry the "int4" dtype, the group shape, the weight scale
    and the zero point.
    Note: Activations are pre-quantized.
    """
    if block_shape is None:
        group_shape = None
    else:
        group_shape = GroupShape(*block_shape)

    a1_desc = FusedMoEQuantDesc(shape=group_shape)
    a2_desc = FusedMoEQuantDesc(shape=group_shape)
    w1_desc = FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp)
    w2_desc = FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp)
    return FusedMoEQuantConfig(
        _a1=a1_desc,
        _a2=a2_desc,
        _w1=w1_desc,
        _w2=w2_desc,
    )
def int8_w8a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for 16-bit float activations and int8 weights.

    The activation descriptors carry only the group shape; the weight
    descriptors carry torch.int8, the group shape, the weight scale and
    the zero point.
    Note: Activations are pre-quantized.
    """
    if block_shape is None:
        group_shape = None
    else:
        group_shape = GroupShape(*block_shape)

    a1_desc = FusedMoEQuantDesc(shape=group_shape)
    a2_desc = FusedMoEQuantDesc(shape=group_shape)
    w1_desc = FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None,
                                w1_zp)
    w2_desc = FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None,
                                w2_zp)
    return FusedMoEQuantConfig(
        _a1=a1_desc,
        _a2=a2_desc,
        _w1=w1_desc,
        _w2=w2_desc,
    )
def biased_moe_quant_config(
    w1_bias: Optional[torch.Tensor],
    w2_bias: Optional[torch.Tensor],
) -> FusedMoEQuantConfig:
    """
    Build the quant config for unquantized activations with weight biases.

    Both activation descriptors are default-constructed; each weight
    descriptor only has its bias set.
    """
    w1_desc = FusedMoEQuantDesc(bias=w1_bias)
    w2_desc = FusedMoEQuantDesc(bias=w2_bias)
    config = FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(),
        _a2=FusedMoEQuantDesc(),
        _w1=w1_desc,
        _w2=w2_desc,
    )
    return config
# A FusedMoEQuantConfig constant for an unquantized MoE op.
# It is built by calling FusedMoEQuantConfig.make() with no arguments, so
# every option is left at its default.
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
@dataclass
@ -315,8 +718,6 @@ class FusedMoEConfig:
# The activation type.
in_dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig] = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
@ -328,34 +729,6 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
if self.quant_config is not None:
return self.quant_config.quant_dtype
else:
return None
@property
def block_shape(self) -> Optional[list[int]]:
if self.quant_config is not None:
return self.quant_config.block_shape
else:
return None
@property
def per_act_token_quant(self) -> bool:
if self.quant_config is not None:
return self.quant_config.per_act_token_quant
else:
return False
@property
def per_out_ch_quant(self) -> bool:
if self.quant_config is not None:
return self.quant_config.per_out_ch_quant
else:
return False
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@ -401,97 +774,6 @@ class FusedMoEConfig:
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (self.quant_config is not None
and self.quant_config.quant_dtype == "nvfp4"
and envs.VLLM_USE_FLASHINFER_MOE_FP4
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(
num_experts: int,
experts_per_token: int,
hidden_dim: int,
num_local_experts: int,
moe_parallel_config: FusedMoEParallelConfig,
in_dtype: torch.dtype,
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config: Optional[Union[FusedMoEQuantConfig,
QuantizationConfig]] = None,
has_bias: bool = False,
) -> "FusedMoEConfig":
_quant_config: Optional[FusedMoEQuantConfig] = None
if quant_config is not None and isinstance(quant_config,
QuantizationConfig):
if hasattr(quant_config, 'weight_block_size'):
block_shape = quant_config.weight_block_size
else:
block_shape = None
per_act_token_quant = False
per_out_ch_quant = False
quant_dtype: Union[torch.dtype, str, None] = None
input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config)
if input_quant is not None:
per_act_token_quant = (input_quant.strategy
== QuantizationStrategy.TOKEN
if input_quant is not None else False)
if input_quant.num_bits == 8:
if input_quant.type == QuantizationType.FLOAT:
quant_dtype = torch.float8_e4m3fn
elif input_quant.type == QuantizationType.INT:
quant_dtype = torch.int8
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = "nvfp4"
if weight_quant is not None:
per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL)
if quant_dtype is not None:
_quant_config = FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
)
else:
_quant_config = FusedMoEQuantConfig()
if moe_parallel_config.dp_size > 1:
logger.warning_once("MoE DP setup unable to determine "
"quantization scheme or unsupported "
"quantization type. This model will "
"not run with DP enabled.")
else:
_quant_config = quant_config
return FusedMoEConfig(
num_experts=num_experts,
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=in_dtype,
quant_config=_quant_config,
max_num_tokens=max_num_tokens,
has_bias=has_bias,
)

View File

@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
self.c_strides1, self.c_strides2, workspace13, workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format, topk_weights)
@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
quant_config,
)
@property
@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
quant_config,
)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
@ -414,16 +395,12 @@ def cutlass_moe_fp8(
w2_q: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
@ -475,10 +452,18 @@ def cutlass_moe_fp8(
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
if per_act_token is None:
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.size(0)
assert quant_config is not None
# The per-token flag must agree with the shape of any static activation
# scale: a single-element scale means per-tensor, anything else per-token.
# The `numel() != 1` test is parenthesized on purpose: Python would
# otherwise chain `flag == numel != 1` as `(flag == numel) and (numel != 1)`,
# comparing the bool flag against the element count itself.
if quant_config.a1_scale is not None:
    assert (quant_config.per_act_token_quant == (
        quant_config.a1_scale.numel() != 1))
if quant_config.a2_scale is not None:
    assert (quant_config.per_act_token_quant == (
        quant_config.a2_scale.numel() != 1))
assert (quant_config.w1_scale is None
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
== w1_q.size(1))))
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
0)
@ -487,12 +472,11 @@ def cutlass_moe_fp8(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
),
)
@ -502,14 +486,9 @@ def cutlass_moe_fp8(
w2_q,
topk_weights,
topk_ids,
False,
activation,
num_experts,
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
activation=activation,
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -542,7 +521,7 @@ def run_cutlass_moe_fp4(
) -> None:
"""
MoE implementation for FP4 Inputs
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
@ -552,16 +531,16 @@ def run_cutlass_moe_fp4(
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
topk_weights: [m, topk] dtype: float8
topk_ids: [m, topk] dtype: float8
m, n, k: Unquantized weight shapes, dtype: int
e: number of experts, dtype: int
@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
return
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
use_batched_format: bool = False,
):
super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
super().__init__(quant_config)
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property
def activation_formats(
self
@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
a1q_scale: Optional[torch.Tensor], # unused
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
a=hidden_states,
a1_gscale=self.a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_blockscale=self.w1_scale,
w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_blockscale=self.w2_scale,
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
@ -788,14 +741,9 @@ def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
@ -805,17 +753,31 @@ def cutlass_moe_fp4(
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
quant_config = FusedMoEQuantConfig.make(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=quant_config.per_act_token_quant,
per_out_ch_quant=quant_config.per_out_ch_quant,
block_shape=quant_config.block_shape,
g1_alphas=quant_config.g1_alphas,
g2_alphas=quant_config.g2_alphas,
a1_gscale=quant_config.a1_gscale,
a2_gscale=quant_config.a2_gscale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
per_out_ch_quant=False,
quant_config=quant_config,
use_batched_format=False,
),
)
@ -830,10 +792,6 @@ def cutlass_moe_fp4(
activation="silu",
global_num_experts=e,
expert_map=None,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
return True
# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
def run_cutlass_block_scaled_fused_experts(
a: torch.Tensor,
w1: torch.Tensor,

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional
import torch
@ -9,9 +8,11 @@ from tqdm import tqdm
import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute,
deepgemm_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__)
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
align = deep_gemm_block_shape()[0]
return align <= M and N % align == 0 and K % align == 0
@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=deep_gemm_block_shape(),
))
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
assert quant_config.block_shape == deep_gemm_block_shape()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
@property
def activation_formats(
@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert self.block_shape is not None
assert a1q_scale is not None
assert w1_scale is not None
assert w2_scale is not None
assert self.a2_scale is None
assert self.block_shape is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states
_, N, K = w1.size()
@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
aq_out=a1q_perm)
assert a1q.size(0) == M_sum
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale),
mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N))
@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
column_major_scales=True,
out_q=quant_out)
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale),
mm2_out, expert_ids)
if apply_router_weight_on_input:
@ -348,9 +336,16 @@ def deep_gemm_moe_fp8(
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=deep_gemm_block_shape())
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(),
DeepGemmExperts(quant_config),
)
return fn(
hidden_states,
@ -358,13 +353,9 @@ def deep_gemm_moe_fp8(
w2,
topk_weights,
topk_ids,
inplace,
activation,
global_num_experts,
expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_scale,
quant_config.a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale
a1_post_scale = quant_config.a1_scale
return (lambda *args: None,
self._do_dispatch(tokens=a1q,
@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids, num_experts,
expert_map,
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()

View File

@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_quant(
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch:
block_k = quant_config.block_shape[
1] if quant_config.block_shape is not None else None
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us.
x, x_scales = x
@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant,
block_shape)
x, x_scales = moe_kernel_quantize_input(
x, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None:
if quant_config.quant_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"
has_per_token_scales = a1_scale.numel(
) != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
has_per_token_scales = quant_config.a1_scale.numel(
) != 1 if quant_config.a1_scale is not None else (
quant_config.a2_scale.numel() != 1
if quant_config.a2_scale is not None else False)
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales")
@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=True)
self.handles[a2a_idx] = handle
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config))
return (
hook,
lambda: self._receiver(expert_x, expert_num_tokens, quant_config.
a1_scale, a1.dtype, quant_config))
def _receiver(
self,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype,
quant_config)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids,
hook, receiver = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Optional
import torch
@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
out_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
quant_config: FusedMoEQuantConfig,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=False,
block_shape=None,
))
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
"Only nvfp4,fp8 quantization are currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
self.out_dtype = out_dtype
@property
@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], # Not used
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
fc2_expert_weights = w2
else:
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not "
"be None for FlashInferExperts")
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
quant_scales = [
self.a1_gscale,
w1_scale.view(torch.int32),
self.w1_scale.view(torch.int32),
self.g1_alphas,
self.a2_gscale,
w2_scale.view(torch.int32),
self.w2_scale.view(torch.int32),
self.g2_alphas,
]
# FlashInfer API requires weight to be long for nvfp4
@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4(
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
a1_gscale=a1_gscale),
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=hidden_states.dtype,
quant_dtype="nvfp4",
quant_config=quant_config,
))
return fused_experts(
@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4(
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
use_dp: bool,
a1_gscale: Optional[torch.Tensor],
num_dispatchers: int = 1,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp
self.a1_gscale = a1_gscale
self.local_tokens = None
@property
@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], # Not used
a2_scale: Optional[torch.Tensor], # Not used
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
self.a1_gscale,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,

View File

@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List # noqa: UP035
from typing import Optional
import torch
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: torch.Tensor,
        x: torch.Tensor,
        w13_weight: torch.Tensor,
        w13_weight_scale_inv: torch.Tensor,
        w2_weight: torch.Tensor,
        w2_weight_scale_inv: torch.Tensor,
        global_num_experts: int,
        top_k: int,
        num_expert_group: int,
        topk_group: int,
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: List[int],  #noqa: UP006
        routed_scaling: float = 1.0) -> torch.Tensor:
    """Quantize ``x`` to FP8 and run the FlashInfer block-scale MoE kernel.

    The activations are quantized with ``per_token_group_quant_fp8`` using
    ``block_shape[1]`` as the group size. The resulting scales are
    transposed and made contiguous before being passed to
    ``flashinfer_trtllm_fp8_block_scale_moe``. The kernel is always invoked
    with ``routing_method_type=2`` and ``use_shuffled_weight=False``.

    Args:
        routing_logits: Forwarded unchanged to the kernel.
        routing_bias: Forwarded unchanged to the kernel.
        x: Activations to quantize; ``x.shape[0]`` also feeds the tile-size
            computation.
        w13_weight: Passed to the kernel as ``gemm1_weights``.
        w13_weight_scale_inv: Passed as ``gemm1_weights_scale``.
        w2_weight: Passed to the kernel as ``gemm2_weights``.
        w2_weight_scale_inv: Passed as ``gemm2_weights_scale``.
        global_num_experts: Passed as ``num_experts``.
        top_k: Passed as ``top_k``.
        num_expert_group: Passed as ``n_group``.
        topk_group: Passed as ``topk_group``.
        intermediate_size: Forwarded unchanged to the kernel.
        expert_offset: Passed as ``local_expert_offset``.
        local_num_experts: Forwarded unchanged to the kernel.
        block_shape: Must equal ``[128, 128]``.
        routed_scaling: Passed as ``routed_scaling_factor``.

    Returns:
        The value returned by ``flashinfer_trtllm_fp8_block_scale_moe``.

    Raises:
        AssertionError: If any of the argument constraints checked below
            is violated.
    """
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    # Argument constraints enforced before the kernel is launched.
    assert top_k <= global_num_experts
    assert top_k <= 8
    assert topk_group <= 4
    assert global_num_experts > num_expert_group
    assert global_num_experts % num_expert_group == 0
    assert global_num_experts % 4 == 0
    assert top_k < (topk_group * global_num_experts / num_expert_group)
    assert block_shape == [128, 128]

    group_size = block_shape[1]
    x_fp8, x_scales = per_token_group_quant_fp8(x, group_size)
    # NOTE: scales of hidden states have to be transposed!
    x_scales_t = x_scales.t().contiguous()

    tile_tokens_dim = calculate_tile_tokens_dim(x.shape[0], top_k,
                                                global_num_experts)

    return flashinfer_trtllm_fp8_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=x_fp8,
        hidden_states_scale=x_scales_t,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
def flashinfer_fused_moe_blockscale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: torch.Tensor,
        x: torch.Tensor,
        w13_weight: torch.Tensor,
        w13_weight_scale_inv: torch.Tensor,
        w2_weight: torch.Tensor,
        w2_weight_scale_inv: torch.Tensor,
        global_num_experts: int,
        top_k: int,
        num_expert_group: int,
        topk_group: int,
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: list[int],
        routed_scaling: float = 1.0) -> torch.Tensor:
    """Fake counterpart of ``flashinfer_fused_moe_blockscale_fp8``.

    Performs no MoE computation. Every argument except ``x`` is ignored.

    Returns:
        An uninitialized tensor created with ``torch.empty_like(x)``.
    """
    placeholder = torch.empty_like(x)
    return placeholder
# TODO(bnell): Does this really need to be a torch.op?
# Register the block-scale FP8 wrapper under the op name
# "flashinfer_fused_moe_blockscale_fp8", paired with its fake implementation.
# No arguments are declared as mutated.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_blockscale_fp8",
    op_func=flashinfer_fused_moe_blockscale_fp8,
    mutates_args=[],
    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    """Quantize activations to FP8 and run the per-tensor-scale MoE kernel.

    ``hidden_states`` is quantized through ``moe_kernel_quantize_input``
    with ``input_scale``, ``quant_dtype=torch.float8_e4m3fn`` and
    ``per_act_token_quant=False``. The scale returned by that call is
    discarded. The quantized activations and the remaining arguments are
    then forwarded to ``flashinfer_trtllm_fp8_per_tensor_scale_moe``.

    Args:
        routing_logits: Forwarded unchanged to the kernel.
        routing_bias: Forwarded unchanged to the kernel; may be ``None``.
        hidden_states: Activations to quantize; ``hidden_states.shape[0]``
            also feeds the tile-size computation.
        input_scale: Scale passed to ``moe_kernel_quantize_input``.
        gemm1_weights: Forwarded unchanged to the kernel.
        gemm2_weights: Forwarded unchanged to the kernel.
        output1_scales_scalar: Forwarded unchanged to the kernel.
        output1_scales_gate_scalar: Forwarded unchanged to the kernel.
        output2_scales_scalar: Forwarded unchanged to the kernel.
        num_experts: Forwarded unchanged to the kernel.
        top_k: Forwarded unchanged to the kernel.
        num_expert_group: Passed as ``n_group``; ``None`` becomes ``0``.
        topk_group: Passed as ``topk_group``; ``None`` becomes ``0``.
        intermediate_size: Forwarded unchanged to the kernel.
        local_expert_offset: Forwarded unchanged to the kernel.
        local_num_experts: Forwarded unchanged to the kernel.
        use_routing_scales_on_input: Forwarded unchanged to the kernel.
        routing_method_type: Forwarded unchanged to the kernel.
        routed_scaling_factor: Forwarded unchanged to the kernel.

    Returns:
        The value returned by ``flashinfer_trtllm_fp8_per_tensor_scale_moe``.
    """
    # The kernel is always given integer group settings: None maps to 0.
    if num_expert_group is None:
        num_expert_group = 0
    if topk_group is None:
        topk_group = 0

    hidden_states_fp8, _ = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
    )

    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)

    tile_tokens_dim = calculate_tile_tokens_dim(hidden_states.shape[0],
                                                top_k, num_experts)

    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states_fp8,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=routing_method_type)
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    """Fake counterpart of ``flashinfer_fused_moe_per_tensor_scale_fp8``.

    Performs no quantization or MoE computation. Every argument except
    ``hidden_states`` is ignored.

    Returns:
        An uninitialized tensor created with
        ``torch.empty_like(hidden_states)``.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
# TODO(bnell): Does this really need to be a torch.op?
# Register the per-tensor-scale FP8 wrapper under the op name
# "flashinfer_fused_moe_per_tensor_scale_fp8", paired with its fake
# implementation.
# NOTE(review): "hidden_states" is declared as mutated, but the wrapper only
# passes it to moe_kernel_quantize_input and reads its shape -- confirm
# whether that helper can write to its input in place.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
    mutates_args=["hidden_states"],
    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)

View File

@ -8,7 +8,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import (
@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype=torch.float32,
device=a1.device)
else:
assert a1_scale is None
assert quant_config.a1_scale is None
b_a1_scale = None
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
a1_scale = normalize_scales_shape(a1_scale)
a2_scale = normalize_scales_shape(a2_scale)
a1_scale = normalize_scales_shape(quant_config.a1_scale)
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
super().__init__(quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
tmp = _resize_cache(workspace2, (num, N))
if self.quant_config.is_quantized:
assert a1q_scale is not None and w1_scale is not None
assert a1q_scale is not None and self.w1_scale is not None
input = self.dequant(hidden_states[expert, :, :],
a1q_scale[expert])
w1_dq = self.dequant(w1[expert], w1_scale[expert])
w1_dq = self.dequant(w1[expert], self.w1_scale[expert])
input = input[:num] @ w1_dq.transpose(0, 1)
else:
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, tmp, input.to(tmp.dtype))
if self.quant_config.is_quantized:
assert w2_scale is not None
w2_dq = self.dequant(w2[expert], w2_scale[expert])
assert self.w2_scale is not None
w2_dq = self.dequant(w2[expert], self.w2_scale[expert])
else:
w2_dq = w2[expert]
@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
super().__init__(quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w1.size(0) == E
assert w2.size(0) == E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config_dtype = self.quant_config.config_name(hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2 = _resize_cache(workspace2,
(E, max_num_tokens, N // 2))
if self.use_fp8_w8a8:
# TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
intermediate_cache1.fill_(0)
a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a1q_scale,
B_scale=w1_scale,
B_zp=w1_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
B_scale=self.w1_scale,
B_zp=self.w1_zp,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)
@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache1.view(-1, N))
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2, a2_scale, max_num_tokens, E, N,
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
self.block_shape)
@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a2q_scale,
B_scale=w2_scale,
B_zp=w2_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
B_scale=self.w2_scale,
B_zp=self.w2_zp,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE kernel."""
"""Fused MoE Triton kernels."""
import functools
import json
import os
# torch.compile needs typing.List. It will fail torch.library.infer_schema
# otherwise
from typing import List # noqa: UP035
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
@ -18,7 +18,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype)
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
_valid_cutlass_block_scaled_grouped_gemm,
run_cutlass_block_scaled_fused_experts)
@ -32,11 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
_resize_cache, activation_without_mul, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4)
from vllm.platforms import current_platform
@ -1049,87 +1045,66 @@ def fused_grouped_topk(
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_mxfp4_w4a4:
return "mxfp4_w4a4"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return "float32"
return None
def inplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, is_act_and_mul,
apply_router_weight_on_input, use_fp8_w8a8,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, w1_bias, w2_bias)
def inplace_fused_experts_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
def inplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
pass
@ -1143,175 +1118,6 @@ direct_register_custom_op(
)
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: List[int], #noqa: UP006
routed_scaling: float = 1.0) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
return flashinfer_trtllm_fp8_block_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=w13_weight,
gemm1_weights_scale=w13_weight_scale_inv,
gemm2_weights=w2_weight,
gemm2_weights_scale=w2_weight_scale_inv,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
global_num_experts),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routed_scaling: float = 1.0) -> torch.Tensor:
return torch.empty_like(x)
direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8,
mutates_args=[],
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, _ = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False)
from vllm.utils.flashinfer import (
flashinfer_trtllm_fp8_per_tensor_scale_moe)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=quant_hidden_states,
gemm1_weights=gemm1_weights,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
top_k, num_experts),
routing_method_type=routing_method_type)
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
pass
direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -1319,7 +1125,6 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1341,37 +1146,37 @@ def outplace_fused_experts(
) -> torch.Tensor:
return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant,
global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
def outplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@ -1403,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
use_fp8_w8a8 = quant_config.use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Falling back to cutlass or triton for some cases would cause
# accuracy issues.
if (allow_deep_gemm and use_fp8_w8a8 and
if (allow_deep_gemm and quant_config.use_fp8_w8a8 and
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
assert quant_config is not None
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
@ -1452,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(
w1, w2, inplace, activation, apply_router_weight_on_input,
expert_map)):
assert quant_config is not None
return run_cutlass_block_scaled_fused_experts(
a=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids)
else:
@ -1478,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
is_act_and_mul=is_act_and_mul,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_channel_quant=per_channel_quant,
use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
w1_zp=quant_config.w1_zp,
w2_zp=quant_config.w2_zp,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
block_shape=quant_config.block_shape,
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias)
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
"""
Get the quantization type based on the quantization strategy flags.
We don't have a quant_config at this point so we need to work backwards.
A return type of None means no quantization is required because the
input is unquantized or has been quantized prior to calling
fused_experts_impl.
"""
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
elif use_mxfp4_w4a4:
return "mxfp4"
return None
def fused_experts_impl(
@ -1508,7 +1328,6 @@ def fused_experts_impl(
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1557,17 +1376,18 @@ def fused_experts_impl(
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
dtype=hidden_states.dtype)
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4)
config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
dtype=hidden_states.dtype)
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
# quantized prior to calling fused_experts.
quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_mxfp4_w4a4=use_mxfp4_w4a4)
get_config_func = functools.partial(
try_get_optimal_moe_config,
@ -1640,7 +1460,7 @@ def fused_experts_impl(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,
quant_dtype=qtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
@ -1671,30 +1491,29 @@ def fused_experts_impl(
B_bias=w1_bias)
# Activation function with multiplication
if activation == "silu" and is_act_and_mul:
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu" and is_act_and_mul:
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "swigluoai" and is_act_and_mul:
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
# Activation function without multiplication
elif activation == "silu":
elif activation == SILU_NO_MUL:
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == "gelu":
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
f"with is_act_and_mul={is_act_and_mul}.")
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
quant_dtype=qtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
@ -1726,164 +1545,13 @@ def fused_experts_impl(
return out_hidden_states
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute a Mixture of Experts (MoE) layer: route every token to its
    top-k experts, then run the two expert weight sets, w1 and w2, through
    fused_experts.

    Routing is chosen as follows: grouped_topk when use_grouped_topk is set,
    otherwise custom_routing_function when one is given, otherwise
    fused_topk.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False. Must be False when is_act_and_mul is False.
    - activation (str): The activation function to apply after the first
        MoE layer.
    - is_act_and_mul (bool): If True, use activation-and-mul function for
        activation (self-gated activation), otherwise use activation function
        for activation (ungated activation).
    - use_grouped_topk (bool): If True, use grouped_topk instead of
        fused_topk. Note: Deepseekv2 model uses grouped_topk.
    - num_expert_group (Optional[int]): Additional parameter for
        grouped_topk; required when use_grouped_topk is True.
    - topk_group (Optional[int]): Additional parameter for grouped_topk;
        required when use_grouped_topk is True.
    - custom_routing_function (Optional[Callable]): Used only when
        use_grouped_topk is False. Called as
        f(hidden_states, gating_output, topk, renormalize) and expected to
        return (topk_weights, topk_ids).
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and
        OCP MXFP4 activation to compute the inner products for w1 and w2.
        Defaults to False.
    - per_channel_quant (bool): Forwarded unchanged to fused_experts.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
    - w1_zp (Optional[torch.Tensor]): Forwarded unchanged to fused_experts
        (presumably the zero point for w1 -- confirm against fused_experts).
    - w2_zp (Optional[torch.Tensor]): Forwarded unchanged to fused_experts
        (presumably the zero point for w2 -- confirm against fused_experts).
    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
    - block_shape (Optional[list[int]]): Optional block size for block-wise
        quantization.
    - w1_bias (Optional[torch.Tensor]): Forwarded unchanged to fused_experts.
    - w2_bias (Optional[torch.Tensor]): Forwarded unchanged to fused_experts.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Ungated activations are only accepted for the out-of-place variant.
    assert is_act_and_mul or inplace is False, (
        "is_act_and_mul=False is not supported with inplace=True")

    # Every routing implementation takes the same four leading arguments.
    routing_inputs = (hidden_states, gating_output, topk, renormalize)
    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(*routing_inputs,
                                              num_expert_group, topk_group)
    elif custom_routing_function is not None:
        topk_weights, topk_ids = custom_routing_function(*routing_inputs)
    else:
        # fused_topk also returns the token/expert indices; they are not
        # needed here.
        topk_weights, topk_ids, _ = fused_topk(*routing_inputs)

    # Quantization-related settings are passed through untouched.
    quant_kwargs = dict(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        use_mxfp4_w4a4=use_mxfp4_w4a4,
        per_channel_quant=per_channel_quant,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        activation=activation,
        is_act_and_mul=is_act_and_mul,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        **quant_kwargs,
    )
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
super().__init__(quant_config)
@property
def activation_formats(
@ -1929,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
@ -1964,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
if global_num_experts == -1:
global_num_experts = E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
config_dtype,
self.quant_config.config_name(hidden_states.dtype),
num_tokens,
block_shape=self.block_shape,
)
@ -2008,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
self.w1_scale,
self.w1_zp,
None, # topk_weights
sorted_token_ids,
expert_ids,
@ -2018,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
B_bias=self.w1_bias,
)
self.activation(activation, intermediate_cache2,
@ -2033,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.quant_dtype,
intermediate_cache2, self.a2_scale, self.quant_dtype,
self.per_act_token_quant, self.block_shape)
invoke_fused_moe_kernel(
@ -2041,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
self.w2_scale,
self.w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
@ -2051,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
B_bias=self.w2_bias,
)
ops.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_mxfp4_w4a4: bool,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
TritonExperts(quant_config),
)

View File

@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.utils import has_triton_kernels
@ -23,9 +25,6 @@ if has_triton_kernels():
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible.")
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
@ -35,20 +34,10 @@ def triton_kernel_moe_forward(
topk: int,
renormalize: bool,
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(gating_output,
@ -64,20 +53,10 @@ def triton_kernel_moe_forward(
gather_idx,
scatter_idx,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=w1_precision,
w2_precision=w2_precision,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
expert_map=expert_map)
# This is a triton implementation of the fused_experts function
@ -90,28 +69,23 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
a1q_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert w1_bias is None or w1_bias.dtype == torch.float32
assert w2_bias is None or w2_bias.dtype == torch.float32
assert (quant_config.w1_bias is None
or quant_config.w1_bias.dtype == torch.float32)
assert (quant_config.w2_bias is None
or quant_config.w2_bias.dtype == torch.float32)
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
@ -130,20 +104,20 @@ def triton_kernel_fused_experts(
intermediate_cache1 = matmul_ogs(
hidden_states,
w1,
w1_bias,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w1_precision,
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
w2,
w2_bias,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision,
precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
quant_config,
max_num_tokens: int,
num_dispatchers: int,
w1_precision: "PrecisionConfig",
w2_precision: "PrecisionConfig",
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
self.w1_bias = w1_bias
self.w2_bias = w2_bias
@property
def activation_formats(
@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states,
w1,
w2,
None,
None,
None,
routing_data=None,
gather_indx=None,
scatter_indx=None,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
use_fp8_w8a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
a2_scale=a2_scale)
a1q_scale=a1q_scale)

View File

@ -22,7 +22,8 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, biased_moe_quant_config)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
# TODO(bnell): also pass quant_config?
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.fused_experts: Optional[Callable] = None
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
@abstractmethod
@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@staticmethod
def _maybe_make_prepare_finalize(
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
moe: FusedMoEConfig,
quant_config: Optional[FusedMoEQuantConfig],
) -> Optional[FusedMoEPrepareAndFinalize]:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
# TODO: could allow this now
assert not moe.use_flashinfer_cutlass_kernels, \
"Must be created in modelopt.py"
if moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
moe.quant_dtype,
per_act_token_quant=moe.per_act_token_quant,
block_shape=moe.block_shape,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
elif moe.use_deepep_ll_kernels:
assert quant_config is not None
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
# Note : We may want to use FP8 dispatch even otherwise just to
# reduce datamovement
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE)
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return prepare_finalize
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]:
if moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
self) -> Optional[FusedMoEPrepareAndFinalize]:
if self.moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(
self.moe, self.moe_quant_config)
else:
return None
@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# prepare_communication_buffer_for_model.
def init_prepare_finalize(self, layer: torch.nn.Module):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
# We must get the quant config here so that the layer is
# completely initialized, i.e. all weights loaded and post
# processed.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
prepare_finalize = self.maybe_make_prepare_finalize()
if prepare_finalize is not None:
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
experts = self.select_gemm_impl(prepare_finalize, layer)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
# Build the FusedMoEQuantConfig for `layer`; every concrete subclass must
# override this. It is called from init_prepare_finalize, which stores the
# result in `self.moe_quant_config`, and lazily from
# FusedMoE.ensure_moe_quant_config when no config has been stored yet.
@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
raise NotImplementedError
@abstractmethod
def apply(
self,
@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def maybe_make_prepare_finalize(
        self) -> Optional[FusedMoEPrepareAndFinalize]:
    """
    Return the prepare/finalize object for this method, if any.

    When the ROCm AITER MoE path is enabled no prepare/finalize object is
    created (None is returned); otherwise the decision is delegated to the
    base class implementation.
    """
    return (None if self.rocm_aiter_moe_enabled else
            super().maybe_make_prepare_finalize())
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts()
return TritonExperts(self.moe_quant_config)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.has_bias:
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.has_bias:
if self.moe.has_bias:
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=params_dtype),
@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count=logical_replica_count,
)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """
    Return the quant config for an unquantized MoE layer.

    Without bias this is the shared FUSED_MOE_UNQUANTIZED_CONFIG; with bias
    a config is built from the layer's w13_bias and w2_bias via
    biased_moe_quant_config.
    """
    if not self.moe.has_bias:
        return FUSED_MOE_UNQUANTIZED_CONFIG
    return biased_moe_quant_config(layer.w13_bias, layer.w2_bias)
def forward_cuda(
self,
layer: torch.nn.Module,
@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count=logical_replica_count)
if self.rocm_aiter_moe_enabled:
assert self.fused_experts is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None:
if self.has_bias:
if self.moe.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_bias=layer.w13_bias if self.has_bias else None,
w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
@ -933,16 +963,18 @@ class FusedMoE(CustomOp):
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
has_bias=has_bias)
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)
self.moe_config = moe
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
@ -990,6 +1022,9 @@ class FusedMoE(CustomOp):
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
# TODO(bnell): flashinfer uses non-batched format.
# Does it really need a batched buffer?
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels):
@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp):
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_config.use_flashinfer_cutlass_kernels
return (self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
and self.moe_config.use_flashinfer_cutlass_kernels)
def update_expert_map(self):
# ep_size and ep_rank should already be updated
@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp):
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
def ensure_moe_quant_config(self):
    """
    Lazily populate the quant method's moe_quant_config.

    If the quant method has no moe_quant_config yet, ask it to build one
    for this layer and store it; otherwise do nothing.
    """
    quant_method = self.quant_method
    if quant_method.moe_quant_config is not None:
        return
    quant_method.moe_quant_config = (
        quant_method.get_fused_moe_quant_config(self))
@staticmethod
def select_experts(
hidden_states: torch.Tensor,
@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp):
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
self.ensure_moe_quant_config()
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp):
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None
self.ensure_moe_quant_config()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
use_flashinfer_cutlass_kernels = (
self.dp_size > 1
and self.moe_config.use_flashinfer_cutlass_kernels)
_use_flashinfer_cutlass_kernels = (self.dp_size > 1 and
self.use_flashinfer_cutlass_kernels)
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or use_flashinfer_cutlass_kernels):
or _use_flashinfer_cutlass_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (

View File

@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional quantized + dispatched a1_scales.
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
self,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: FusedMoEQuantConfig,
):
if quant_config is not None:
self.quant_config = quant_config
else:
self.quant_config = FusedMoEQuantConfig()
"""
quant_config: Quantization parameters for this experts instance.
"""
self.quant_config = quant_config
@property
@abstractmethod
@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
#
# Various helpers for accessing quantization parameters from the
# quant_config.
#
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
    """Return the ``quant_dtype`` stored in this instance's quant_config.

    The value is not always a ``torch.dtype``: the layer code compares the
    same config field against the string tag ``"nvfp4"``, so the return
    type also admits ``str``. ``None`` is allowed as well.
    """
    return self.quant_config.quant_dtype
@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant
@property
def a1_scale(self) -> Optional[torch.Tensor]:
    """Return ``a1_scale`` from this instance's quant_config (may be None)."""
    return self.quant_config.a1_scale


@property
def a2_scale(self) -> Optional[torch.Tensor]:
    """Return ``a2_scale`` from this instance's quant_config (may be None)."""
    return self.quant_config.a2_scale


@property
def a1_gscale(self) -> Optional[torch.Tensor]:
    """Return ``a1_gscale`` from this instance's quant_config (may be None)."""
    return self.quant_config.a1_gscale


@property
def a2_gscale(self) -> Optional[torch.Tensor]:
    """Return ``a2_gscale`` from this instance's quant_config (may be None)."""
    return self.quant_config.a2_gscale


@property
def w1_scale(self) -> Optional[torch.Tensor]:
    """Return ``w1_scale`` from this instance's quant_config (may be None)."""
    return self.quant_config.w1_scale


@property
def w2_scale(self) -> Optional[torch.Tensor]:
    """Return ``w2_scale`` from this instance's quant_config (may be None)."""
    return self.quant_config.w2_scale


@property
def w1_zp(self) -> Optional[torch.Tensor]:
    """Return ``w1_zp`` from this instance's quant_config (may be None)."""
    return self.quant_config.w1_zp


@property
def w2_zp(self) -> Optional[torch.Tensor]:
    """Return ``w2_zp`` from this instance's quant_config (may be None)."""
    return self.quant_config.w2_zp


@property
def w1_bias(self) -> Optional[torch.Tensor]:
    """Return ``w1_bias`` from this instance's quant_config (may be None)."""
    return self.quant_config.w1_bias


@property
def w2_bias(self) -> Optional[torch.Tensor]:
    """Return ``w2_bias`` from this instance's quant_config (may be None)."""
    return self.quant_config.w2_bias


@property
def g1_alphas(self) -> Optional[torch.Tensor]:
    """Return ``g1_alphas`` from this instance's quant_config (may be None)."""
    return self.quant_config.g1_alphas


@property
def g2_alphas(self) -> Optional[torch.Tensor]:
    """Return ``g2_alphas`` from this instance's quant_config (may be None)."""
    return self.quant_config.g2_alphas
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations
choose to do weight application.
choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
used for a1. Result of quantization from prepare/finalize and not
from the FusedMoEQuantConfig.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
_chunk_scales(a2_scale, s,
e), topk_ids[s:e], topk_weights[s:e])
return (
a1q[s:e],
_chunk_scales(a1q_scale, s, e),
_chunk_scales(self.fused_experts.a2_scale, s, e),
topk_ids[s:e],
topk_weights[s:e],
)
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, (
@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if quant_config.per_act_token_quant else a1_scale),
a1, (None if quant_config.per_act_token_quant else
quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape)
@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,

View File

@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
a1, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None

View File

@ -7,6 +7,8 @@ from typing import Optional
import torch
from vllm import envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk(
def rocm_aiter_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU)
@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
expert_mask = None
# w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
if (quant_config.per_act_token_quant and apply_router_weight_on_input
and quant_config.use_fp8_w8a8):
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC.
@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
w2,
topk_weights,
topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_scale=quant_config.w1_scale,
fc2_scale=quant_config.w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False,
@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
quant_method = QuantMethod.NO.value
# w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is\
not supported for block scaled moe")
assert w1_scale is not None
assert w2_scale is not None
assert quant_config.w1_scale is not None
assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value
elif use_fp8_w8a8:
elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value
@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
expert_mask=expert_mask,
quant_method=quant_method,
activation_method=activation_method,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input)

View File

@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.triton_expert = TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
super().__init__(quant_config)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
self.triton_expert = TritonExperts(quant_config)
self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None
self.quant_config) if self.allow_deep_gemm else None
@property
def activation_formats(
@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,

View File

@ -5,7 +5,8 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils import next_power_of_2
@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
w13_bias,
w2_bias,
max_capture_size,
):
super().__init__(moe.quant_config)
super().__init__(quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.w13_bias = w13_bias
self.w2_bias = w2_bias
self.max_capture_size = max_capture_size
@property
@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16).view(torch.int16)
assert w1_scale is not None
assert w2_scale is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
kwargs = {
"topk_ids":
packed_tensor,
@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"gemm1_weights":
w1,
"gemm1_weights_scale":
w1_scale,
self.w1_scale,
"gemm1_bias":
self.w13_bias,
self.w1_bias,
"gemm1_alpha":
self.gemm1_alpha,
"gemm1_beta":
@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"gemm2_weights":
w2,
"gemm2_weights_scale":
w2_scale,
self.w2_scale,
"gemm2_bias":
self.w2_bias,
"output1_scale_scalar":

View File

@ -268,3 +268,7 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
    """Return ``activation`` with the ``_no_mul`` suffix appended."""
    suffix = "_no_mul"
    return activation + suffix

View File

@ -9,8 +9,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Return the fused-MoE quant config for ``layer``.

    This method always returns ``None``: no ``FusedMoEQuantConfig`` is
    built here, and ``layer`` is not inspected.
    """
    no_config: Optional[FusedMoEQuantConfig] = None
    return no_config
def apply(
self,
layer: torch.nn.Module,

View File

@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union
import torch
from packaging import version
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
**extra_weight_attrs,
)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Return the fused-MoE quant config for ``layer``.

    This method always returns ``None``: no ``FusedMoEQuantConfig`` is
    built here, and ``layer`` is not inspected.
    """
    no_config: Optional[FusedMoEQuantConfig] = None
    return no_config
def apply(
self,
layer: torch.nn.Module,
@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
def _create_weights_4bit(

View File

@ -16,8 +16,11 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
return CompressedTensorsW4A4MoeMethod(layer.moe_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
def __init__(self, moe: FusedMoEConfig):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
super().__init__(moe)
@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False)
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer:
return super().maybe_make_prepare_finalize(moe)
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.use_marlin:
return None
elif not self.allow_flashinfer:
return super().maybe_make_prepare_finalize()
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the NVFP4 quant config from the scale tensors held on ``layer``.

    Returns ``None`` when ``self.use_marlin`` is set; in that case ``layer``
    is not read. Otherwise the layer's alphas, input global scales and
    weight scales are forwarded to ``nvfp4_moe_quant_config``.
    """
    if not self.use_marlin:
        # Gather the arguments in the same order the config factory is
        # called with, then forward them as keywords.
        scale_kwargs = dict(
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
        return nvfp4_moe_quant_config(**scale_kwargs)

    return None
def apply(
self,
layer: torch.nn.Module,
@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")
@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
#
if self.use_marlin:
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_map=expert_map,
workspace=layer.workspace)
# FlashInfer fused experts path
if self.fused_experts is not None:
elif self.fused_experts is not None:
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
# FlashInfer fused experts path
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4)
@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
assert self.moe_quant_config is not None
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
).to(x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
self.fused_experts_func = None
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
if self.use_cutlass:
device = layer.w13_weight.device
@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
device=device,
dtype=torch.int64)
def maybe_make_prepare_finalize(
        self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    """Return a prepare/finalize object, or ``None`` for special paths.

    When either ``self.use_marlin`` or ``self.rocm_aiter_moe_enabled`` is
    set, this returns ``None``. Otherwise the decision is delegated to the
    parent class implementation.
    """
    if not self.use_marlin and not self.rocm_aiter_moe_enabled:
        return super().maybe_make_prepare_finalize()
    return None
def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
assert self.moe_quant_config is not None
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logger.debug("CutlassBatchedExpertsFp8(%s)",
self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
moe.num_local_experts,
self.moe.num_local_experts,
num_dispatchers,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
)
self.disable_expert_map = (num_dispatchers > 1
@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
)
assert max_num_tokens_per_rank is not None
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
quant_config=self.moe_quant_config,
)
else:
return TritonExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
)
logger.debug("TritonExperts(%s)", self.__class__.__name__)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the FP8 W8A8 quantization config from ``layer``'s scales.

    Args:
        layer: Module carrying ``w13_weight_scale``, ``w2_weight_scale``,
            ``w13_input_scale`` and ``w2_input_scale``.

    Returns:
        ``None`` on the Marlin path; otherwise the result of
        ``fp8_w8a8_moe_quant_config`` called with the layer's scales and
        two flags derived from the input/weight quantization strategies.
    """
    if self.use_marlin:
        return None

    # per_act_token_quant is set when the input strategy is TOKEN.
    token_strategy_acts = (
        self.input_quant.strategy == QuantizationStrategy.TOKEN)
    # per_out_ch_quant is set when the weight strategy is CHANNEL.
    channel_strategy_weights = (
        self.weight_quant.strategy == QuantizationStrategy.CHANNEL)

    return fp8_w8a8_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        per_act_token_quant=token_strategy_acts,
        per_out_ch_quant=channel_strategy_weights,
    )
def apply(
self,
@ -841,92 +861,19 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)
# cutlass path
if self.use_cutlass:
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
# small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
if self.fused_experts is None:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
expert_map=expert_map)
#
# Note: the order here is important. self.fused_experts can override
# cutlass fp8 or fused_experts but not marlin or rocm.
#
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@ -944,26 +891,95 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map=expert_map,
workspace=layer.workspace)
assert self.fused_experts_func is not None
elif self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
assert self.fused_experts is None
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
return self.fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
elif self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
)
# cutlass path
elif self.use_cutlass:
assert self.moe_quant_config is not None
# small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_config=self.moe_quant_config,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Intentionally a no-op: nothing is done to ``layer`` here."""
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the INT8 W8A8 quantization config from ``layer``'s scales.

    Args:
        layer: Module carrying ``w13_weight_scale``, ``w2_weight_scale``,
            ``w13_input_scale`` and ``w2_input_scale``.

    Returns:
        The result of ``int8_w8a8_moe_quant_config``;
        ``per_act_token_quant`` is always passed as ``True``.
    """
    layer_scales = dict(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
    )
    return int8_w8a8_moe_quant_config(**layer_scales,
                                      per_act_token_quant=True)
def apply(
self,
layer: torch.nn.Module,
@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
quant_config=self.moe_quant_config,
)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.workspace = marlin_make_workspace_new(device, 4)
def get_fused_moe_quant_config(
    self,
    layer: torch.nn.Module,
) -> Optional[FusedMoEQuantConfig]:
    """Return ``None``: this method supplies no ``FusedMoEQuantConfig``.

    ``layer`` is accepted for interface compatibility and is not read.
    """
    return None
def apply(
self,
layer: torch.nn.Module,
@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale.transpose(1, 2).contiguous(),
requires_grad=False)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the weight-only quantization config for 4- or 8-bit weights.

    Args:
        layer: Module carrying ``w13_weight_scale`` and
            ``w2_weight_scale``.

    Returns:
        The result of ``int4_w4a16_moe_quant_config`` when ``num_bits`` is
        4, or of ``int8_w8a16_moe_quant_config`` when it is 8. No zero
        points are supplied and the block shape is ``[0, group_size]``.

    Raises:
        AssertionError: if ``num_bits`` is neither 4 nor 8.
    """
    assert self.num_bits in (4, 8)

    if self.num_bits == 4:
        make_config = int4_w4a16_moe_quant_config
    else:
        make_config = int8_w8a16_moe_quant_config

    return make_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        w1_zp=None,
        w2_zp=None,
        block_shape=[0, self.group_size],
    )
def apply(
self,
layer: torch.nn.Module,
@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=self.num_bits == 4,
use_int8_w8a16=self.num_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size])
quant_config=self.moe_quant_config,
)

View File

@ -8,6 +8,8 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the INT8 weight-only quantization config for ``layer``.

    Args:
        layer: Module carrying ``w13_scale`` and ``w2_scale``.

    Returns:
        The result of ``int8_w8a16_moe_quant_config`` called with the
        layer's two weight scales and no zero points.
    """
    weight_scales = {
        "w1_scale": layer.w13_scale,
        "w2_scale": layer.w2_scale,
    }
    return int8_w8a16_moe_quant_config(**weight_scales,
                                       w1_zp=None,
                                       w2_zp=None)
def apply(
self,
layer: torch.nn.Module,
@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
quant_config=self.moe_quant_config,
)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):

View File

@ -14,9 +14,11 @@ from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    """Return the prepare/finalize object for ``moe``.

    Args:
        moe: MoE configuration forwarded to whichever builder is used.

    Returns:
        For the FlashInfer CUTLASS backend, the object built by
        ``build_flashinfer_fp8_cutlass_moe_prepare_finalize`` from ``moe``
        and ``self.layer``; for every other backend, the parent class's
        result.
    """
    if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        flashinfer_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            moe,
            layer=self.layer,
        )
        logger.debug_once("%s", flashinfer_pf.__class__.__name__)
        return flashinfer_pf

    return super().maybe_make_prepare_finalize(moe)
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv)
def maybe_make_prepare_finalize(
        self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    """Choose the prepare/finalize object for the active backend.

    Returns:
        * ``None`` when ROCm AITER or Marlin is enabled, or when the
          FlashInfer backend is TENSORRT_LLM.
        * For the FlashInfer CUTLASS backend, the object built by
          ``build_flashinfer_fp8_cutlass_moe_prepare_finalize`` from
          ``self.moe``.
        * Otherwise, the parent class's result.
    """
    if (self.rocm_aiter_moe_enabled or self.use_marlin
            or self.flashinfer_moe_backend
            == FlashinferMoeBackend.TENSORRT_LLM):
        return None

    if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
        return super().maybe_make_prepare_finalize()

    flashinfer_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        self.moe)
    logger.debug_once("%s", flashinfer_pf.__class__.__name__)
    return flashinfer_pf
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
assert self.moe_quant_config is not None
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=False,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl(
moe,
self.layer,
self.moe,
self.moe_quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.__class__.__name__, self.quant_config.weight_block_size,
False)
return TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the FP8 W8A8 quantization config from ``layer``'s scales.

    Args:
        layer: Module carrying the weight scales (``*_weight_scale_inv``
            when ``self.block_quant`` is set, ``*_weight_scale``
            otherwise) and the ``w13_input_scale``/``w2_input_scale``
            activation scales.

    Returns:
        ``None`` on the Marlin path; otherwise the result of
        ``fp8_w8a8_moe_quant_config`` with ``block_shape`` taken from
        ``self.quant_config.weight_block_size``.
    """
    if self.use_marlin:
        return None

    # Block-quantized layers keep their weight scales under "*_scale_inv".
    if self.block_quant:
        w13_scale = layer.w13_weight_scale_inv
        w2_scale = layer.w2_weight_scale_inv
    else:
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

    return fp8_w8a8_moe_quant_config(
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        block_shape=self.quant_config.weight_block_size,
    )
def apply(
self,
layer: torch.nn.Module,
@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and self.fused_experts is None):
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_replica_count=logical_replica_count,
)
#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
assert self.fused_experts is None
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
use_fp8_w8a8=True,
apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
expert_map=expert_map)
expert_map=expert_map,
quant_config=self.moe_quant_config)
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@ -1105,40 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None
assert (not renormalize and custom_routing_function is not None)
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
common_kwargs = dict(
elif self.fused_experts:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -1149,26 +1133,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None
assert (not renormalize and custom_routing_function is not None)
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.fused_experts is not None:
return self.fused_experts(**common_kwargs)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
**common_kwargs,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm),
)
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8KVCacheMethod(BaseKVCacheMethod):

View File

@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
def get_fused_moe_quant_config(
    self,
    layer: torch.nn.Module,
) -> Optional[FusedMoEQuantConfig]:
    """Return ``None``: this method supplies no ``FusedMoEQuantConfig``.

    ``layer`` is accepted for interface compatibility and is not read.
    """
    return None
def apply(
self,
layer: torch.nn.Module,

View File

@ -9,8 +9,10 @@ import torch
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
    self,
    layer: torch.nn.Module,
) -> Optional[FusedMoEQuantConfig]:
    """Return ``None``: this method supplies no ``FusedMoEQuantConfig``.

    ``layer`` is accepted for interface compatibility and is not read.
    """
    return None
def apply(
self,
layer: torch.nn.Module,

View File

@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
use_prepack=True,
)
def get_fused_moe_quant_config(
    self,
    layer: torch.nn.Module,
) -> Optional[FusedMoEQuantConfig]:
    """Return ``None``: this method supplies no ``FusedMoEQuantConfig``.

    ``layer`` is accepted for interface compatibility and is not read.
    """
    return None
def apply(
self,
layer: torch.nn.Module,

View File

@ -11,7 +11,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import (
@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.fused_experts is not None or \
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
return super().maybe_make_prepare_finalize(moe)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe,
layer=self.layer,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
prepare_finalize = (
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
experts = select_cutlass_fp8_gemm_impl(
moe,
self.layer,
self.moe,
self.moe_quant_config,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the FP8 W8A8 quantization config from ``layer``'s scales.

    Args:
        layer: Module carrying ``w13_weight_scale``, ``w2_weight_scale``,
            ``w13_input_scale`` and ``w2_input_scale``.

    Returns:
        ``None`` when the FlashInfer backend is TENSORRT_LLM; otherwise
        the result of ``fp8_w8a8_moe_quant_config`` with
        ``per_act_token_quant=False``.
    """
    uses_trtllm = (
        self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM)
    if uses_trtllm:
        return None

    layer_scales = dict(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
    )
    return fp8_w8a8_moe_quant_config(**layer_scales,
                                     per_act_token_quant=False)
def apply(
self,
layer: torch.nn.Module,
@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert self.fused_experts is None
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert not renormalize
@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
#
# Note: the order here is important. self.fused_experts can override
# cutlass or fused_experts.
#
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
assert self.moe_quant_config is not None
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
quant_config=self.moe_quant_config,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class ModelOptNvFp4Config(QuantizationConfig):
@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.CUTLASS):
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if (self.use_marlin
or (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM)):
return None
elif (self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = (
build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
))
build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe))
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(moe)
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_weight = Parameter(layer.w2_weight.data,
requires_grad=False)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the NVFP4 quantization config from ``layer``'s scales.

    Args:
        layer: Module carrying ``w13_weight_scale``, ``w2_weight_scale``,
            ``g1_alphas``, ``g2_alphas``, ``w13_input_scale_quant`` and
            ``w2_input_scale_quant``.

    Returns:
        ``None`` on the Marlin path or when the FlashInfer backend is
        TENSORRT_LLM; otherwise the result of ``nvfp4_moe_quant_config``.
    """
    # `or` short-circuits, so the backend is only inspected when Marlin
    # is not in use.
    no_quant_config = (self.use_marlin or self.flashinfer_moe_backend
                       == FlashinferMoeBackend.TENSORRT_LLM)
    if no_quant_config:
        return None

    return nvfp4_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        g1_alphas=layer.g1_alphas,
        g2_alphas=layer.g2_alphas,
        a1_gscale=layer.w13_input_scale_quant,
        a2_gscale=layer.w2_input_scale_quant,
    )
def apply(
self,
layer: torch.nn.Module,
@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
assert self.fused_experts is None
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4,
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
# trtllm.
#
if self.use_marlin:
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
workspace=layer.workspace)
if self.fused_experts is not None:
elif self.fused_experts is not None:
assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
out = self.fused_experts(
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4)
assert self.moe_quant_config is not None
out = flashinfer_cutlass_moe_fp4(
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
inplace=False, # TODO(shuw): fix later, now output is high prec
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
out = cutlass_moe_fp4(
assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
# TODO: derive from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
return out
)

View File

@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the weight-only (int4 or int8, 16-bit activation) MoE config.

    The builder is chosen from ``self.quant_config.weight_bits``. Zero
    points are forwarded only when ``self.quant_config.has_zp`` is set,
    and the block shape is ``[0, layer.group_size]``.
    """
    bits = self.quant_config.weight_bits
    uses_zero_points = self.quant_config.has_zp
    assert bits in (4, 8)

    if bits == 4:
        make_config = int4_w4a16_moe_quant_config
    else:
        make_config = int8_w8a16_moe_quant_config

    return make_config(
        w1_scale=layer.w13_scales,
        w2_scale=layer.w2_scales,
        # Zero points are only read from the layer when the scheme has them.
        w1_zp=layer.w13_qzeros if uses_zero_points else None,
        w2_zp=layer.w2_qzeros if uses_zero_points else None,
        block_shape=[0, layer.group_size],
    )
def apply(
self,
layer: torch.nn.Module,
@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts(
x,
layer.w13_qweight,
@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size])
quant_config=self.moe_quant_config,
)
@staticmethod
def get_weight_loader(layer, weight_loader):

View File

@ -12,6 +12,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the MXFP4 (w4a4) fused-MoE quantization config for ``layer``.

    Returns ``None`` for the Marlin backend. For the Triton backend the
    layer's precision configs are passed as the "scales"; every other
    backend passes the layer's weight scales. Biases are always forwarded.
    """
    if self.mxfp4_backend == Mxfp4Backend.MARLIN:
        return None

    # NOTE(review): the call site this replaces read the precision configs
    # from `self`, not `layer` -- confirm they are stored on the layer.
    on_triton = self.mxfp4_backend == Mxfp4Backend.TRITON
    if on_triton:
        w13_scale, w2_scale = (layer.w13_precision_config,
                               layer.w2_precision_config)
    else:
        w13_scale, w2_scale = (layer.w13_weight_scale,
                               layer.w2_weight_scale)

    return mxfp4_w4a4_moe_quant_config(
        w1_bias=layer.w13_bias,
        w2_bias=layer.w2_bias,
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
"w13_bias": layer.w13_bias,
"w2_bias": layer.w2_bias,
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(moe, **kwargs)
assert self.moe_quant_config is not None
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
else:
# Use matmul_ogs from triton_kernels here!
raise NotImplementedError(
@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
renormalize=renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_precision=self.w13_precision_config,
w2_precision=self.w2_precision_config,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:

View File

@ -11,6 +11,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the FP8 w8a8 fused-MoE quantization config for ``layer``.

    Weight and input scales are read from ``layer``. Per-activation-token
    quantization is requested exactly when ``self.weight_qscheme`` equals
    ``"per_channel"``.
    """
    w13_scale = layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale
    a13_scale = layer.w13_input_scale
    a2_scale = layer.w2_input_scale
    per_channel = self.weight_qscheme == "per_channel"

    return fp8_w8a8_moe_quant_config(
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        a1_scale=a13_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_channel,
    )
def apply(
self,
layer: torch.nn.Module,
@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
quant_config=self.moe_quant_config,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
quant_config=self.moe_quant_config)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the MXFP4 (w4a4) fused-MoE quantization config for ``layer``.

    Only the weight scales come from ``layer``; activation scales and the
    block shape are passed explicitly as ``None``.
    """
    w13_scale = layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale

    return mxfp4_w4a4_moe_quant_config(
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,
    )
def apply(
self,
layer: torch.nn.Module,
@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_mxfp4_w4a4=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
activation=activation,
quant_config=self.moe_quant_config,
)
return out

View File

@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the weight-only (int4 or int8, 16-bit activation) MoE config.

    The builder is chosen from ``self.quant_config.weight_bits``. No zero
    points are passed, and the block shape is
    ``[0, self.quant_config.group_size]``.
    """
    bits = self.quant_config.weight_bits
    group = self.quant_config.group_size
    assert bits in (4, 8)

    if bits == 4:
        make_config = int4_w4a16_moe_quant_config
    else:
        make_config = int8_w8a16_moe_quant_config

    return make_config(
        w1_scale=layer.w13_scale,
        w2_scale=layer.w2_scale,
        w1_zp=None,
        w2_zp=None,
        block_shape=[0, group],
    )
def apply(
self,
layer: torch.nn.Module,
@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
ret = fused_experts(
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
block_shape=[0, group_size])
return ret
quant_config=self.moe_quant_config,
)
def rtn_quantize(tensor: torch.Tensor, num_bits: int,

View File

@ -7,7 +7,8 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
a1_gscale: torch.Tensor,
) -> mk.FusedMoEPrepareAndFinalize:
moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
def select_nvfp4_gemm_impl(
moe: FusedMoEConfig,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
moe_quant_config: FusedMoEQuantConfig,
allow_flashinfer: bool,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer:
return FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=moe.in_dtype,
quant_dtype="nvfp4",
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,

View File

@ -8,7 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8(
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_scalar to be initialized")
assert layer.output1_scales_scalar is not None, (
@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: Optional[FusedMoEConfig],
layer: torch.nn.Module,
) -> mk.FusedMoEPrepareAndFinalize:
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp, a1_gscale=layer.w13_input_scale)
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
def select_cutlass_fp8_gemm_impl(
moe: Optional[FusedMoEConfig],
layer: torch.nn.Module,
quant_config: FusedMoEQuantConfig,
out_dtype: Optional[torch.dtype] = None,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for fused-MoE layers"""
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
"FusedMoE flashinfer kernels are only supported for Llama4"
if moe is not None:
return FlashInferExperts(
g1_alphas=layer.output1_scales_gate_scalar,
g2_alphas=layer.output2_scales_scalar,
a1_gscale=layer.w13_input_scale,
a2_gscale=layer.w2_input_scale_inv,
out_dtype=moe.in_dtype,
quant_dtype=torch.float8_e4m3fn,
quant_config=quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl(
assert out_dtype is not None, (
"If moe config is None, out_dtype must be passed")
return FlashInferExperts(
g1_alphas=layer.output1_scales_gate_scalar,
g2_alphas=layer.output2_scales_scalar,
a1_gscale=layer.w13_input_scale,
a2_gscale=layer.w2_input_scale_inv,
out_dtype=out_dtype,
quant_dtype=torch.float8_e4m3fn,
quant_config=quant_config,
)
@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8(
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
layer=layer),
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
select_cutlass_fp8_gemm_impl(moe=None,
layer=layer,
quant_config=quant_config,
out_dtype=hidden_states.dtype))
return fused_experts(

View File

@ -411,6 +411,7 @@ def per_token_group_quant_fp8(
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
# prefer CUDA kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
fp8_min, fp8_max, use_ue8m0)

View File

@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, torch_vllm_outplace_fused_experts)
from vllm.model_executor.layers.fused_moe import (activation_without_mul,
fused_topk)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
@ -230,7 +230,7 @@ class NomicMoE(nn.Module):
self.hidden_size = hidden_size
self.total_intermediate_size = intermediate_size
self.intermediate_size = divide(intermediate_size, self.tp_size)
self.hidden_act = hidden_act
self.hidden_act = activation_without_mul(hidden_act)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@ -297,14 +297,14 @@ class NomicMoE(nn.Module):
router_logits,
self.top_k,
renormalize=False)
final_hidden_states = torch_vllm_outplace_fused_experts(
final_hidden_states = torch.ops.vllm.outplace_fused_experts(
hidden_states=hidden_states,
w1=self.w1,
w2=self.w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=self.hidden_act,
is_act_and_mul=False,
)
if self.tp_size > 1:

View File

@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
topk_weights, topk_ids, _ = fused_topk(
hidden_states,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob)
final_hidden_states = fused_experts(hidden_states,
self.w1,
self.w2,
topk_weights,
topk_ids,
inplace=True)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output

View File

@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True)
topk_weights, topk_ids, _ = fused_topk(hidden_states,
router_logits,
self.top_k,
renormalize=True)
final_hidden_states = fused_experts(hidden_states,
self.ws,
self.w2s,
topk_weights,
topk_ids,
inplace=True)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(

View File

@ -702,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
return self.model.get_expert_mapping()

View File

@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
if not (isinstance(module, FusedMoE)
and module.moe_config.quant_dtype == torch.float8_e4m3fn
and module.moe_config.block_shape == deep_gemm_block_shape()):
if not isinstance(module, FusedMoE):
return False
moe_quant_config = module.quant_method.get_fused_moe_quant_config(module)
if (moe_quant_config is None
or moe_quant_config.quant_dtype != torch.float8_e4m3fn
or moe_quant_config.block_shape != deep_gemm_block_shape()):
return False
if not isinstance(module.quant_method.fused_experts,