diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py index 35c20ee41b9a9..726a2a371d109 100644 --- a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.scalar_type import scalar_types @@ -140,6 +144,12 @@ def bench_run( a_fp8_scale: torch.Tensor, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + for _ in range(num_repeats): fused_experts( a, @@ -147,10 +157,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_fp8_scale, + quant_config=quant_config, ) def run_cutlass_moe_fp4( @@ -172,25 +179,27 @@ def bench_run( device: torch.device, num_repeats: int, ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) for _ in range(num_repeats): with nvtx.annotate("cutlass_moe_fp4", color="green"): cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, - a2_gscale=a2_gs, w1_fp4=w1_fp4, - w1_blockscale=w1_blockscale, - w1_alphas=w1_gs, w2_fp4=w2_fp4, - w2_blockscale=w2_blockscale, - w2_alphas=w2_gs, topk_weights=topk_weights, topk_ids=topk_ids, m=m, n=n, k=k, e=num_experts, - device=device, + quant_config=quant_config, ) def run_cutlass_from_graph( @@ -211,26 +220,29 @@ def bench_run( e: int, device: torch.device, ): + quant_config = nvfp4_moe_quant_config( + 
a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): return cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, w1_fp4=w1_fp4, - w1_blockscale=w1_blockscale, - w1_alphas=w1_alphas, - a2_gscale=a2_gs, w2_fp4=w2_fp4, - w2_blockscale=w2_blockscale, - w2_alphas=w2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, m=m, n=n, k=k, e=num_experts, - device=device, + quant_config=quant_config, ) def run_triton_from_graph( @@ -246,16 +258,18 @@ def bench_run( with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) return fused_experts( a, w1, w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_fp8_scale, + quant_config=quant_config, ) def replay_graph(graph, num_repeats): diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index a6b42406b5cb0..14330ae6f03c5 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, @@ -96,6 +97,11 @@ def bench_run( a_scale: torch.Tensor, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) for _ in range(num_repeats): fused_experts( a, @@ -103,10 
+109,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, + quant_config=quant_config, ) def run_cutlass_moe( @@ -125,6 +128,12 @@ def bench_run( per_act_token: bool, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + for _ in range(num_repeats): cutlass_moe_fp8( a, @@ -132,14 +141,11 @@ def bench_run( w2, topk_weights, topk_ids, - w1_scale, - w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, - per_act_token, - a1_scale=None, + quant_config=quant_config, ) def run_cutlass_from_graph( @@ -156,6 +162,12 @@ def bench_run( topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): @@ -165,14 +177,11 @@ def bench_run( w2_q, topk_weights, topk_ids, - w1_scale, - w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, - per_act_token, - a1_scale=None, + quant_config=quant_config, ) def run_triton_from_graph( @@ -185,6 +194,11 @@ def bench_run( w2_scale: torch.Tensor, a_scale: torch.Tensor, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): @@ -194,10 +208,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, + quant_config=quant_config, ) def replay_graph(graph, num_repeats): diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 837b2b0c10447..d2beb28f70233 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -14,6 +14,10 @@ import ray import 
torch from ray.experimental.tqdm_ray import tqdm +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, +) from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform from vllm.transformers_utils.config import get_config @@ -134,43 +138,36 @@ def benchmark_config( def run(): from vllm.model_executor.layers.fused_moe import override_config + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = FusedMoEQuantConfig.make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) + with override_config(config): - if use_deep_gemm: - topk_weights, topk_ids, token_expert_indices = fused_topk( - x, input_gating, topk, False - ) - return fused_experts( - x, - w1, - w2, - topk_weights, - topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - allow_deep_gemm=True, - ) - else: - fused_moe( - x, - w1, - w2, - input_gating, - topk, - renormalize=True, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - ) + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + return fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + quant_config=quant_config, + allow_deep_gemm=use_deep_gemm, + ) # JIT compilation & warmup run() @@ -414,7 +411,7 @@ class BenchmarkWorker: use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) - dtype_str = get_config_dtype_str( + dtype_str = _get_config_dtype_str( dtype, 
use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which @@ -547,7 +544,7 @@ def save_configs( block_quant_shape: list[int], save_dir: str, ) -> None: - dtype_str = get_config_dtype_str( + dtype_str = _get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 ) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index a10666b6ec9a7..b5fcc4cd70bf8 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx -from .mk_objects import (expert_info, make_fused_experts, +from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts, make_prepare_finalize, prepare_finalize_info) from .parallel_utils import ProcessGroupInfo @@ -40,7 +40,7 @@ class Config: E: int topks: Union[list[int], int] dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] + quant_config: Optional[TestMoEQuantConfig] prepare_finalize_type: mk.FusedMoEPrepareAndFinalize fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute @@ -52,7 +52,7 @@ class Config: def __post_init__(self): if self.quant_config is None: - self.quant_config = FusedMoEQuantConfig() + self.quant_config = TestMoEQuantConfig(None, False, False, None) def describe(self) -> str: s = "" @@ -275,21 +275,19 @@ class WeightTensors: or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) def to_current_device(self): - self.w1 = self.w1.to(device=torch.cuda.current_device()) - self.w2 = self.w2.to(device=torch.cuda.current_device()) + device = torch.cuda.current_device() + self.w1 = self.w1.to(device=device) + self.w2 = self.w2.to(device=device) - if self.is_quantized(): - assert self.w1_scale 
is not None - assert self.w2_scale is not None - self.w1_scale = self.w1_scale.to( - device=torch.cuda.current_device()) - self.w2_scale = self.w2_scale.to( - device=torch.cuda.current_device()) + if self.w1_scale is not None: + self.w1_scale = self.w1_scale.to(device=device) + if self.w2_scale is not None: + self.w2_scale = self.w2_scale.to(device=device) if self.w1_gs is not None: - assert self.w2_gs is not None - self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) - self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) + self.w1_gs = self.w1_gs.to(device=device) + if self.w2_gs is not None: + self.w2_gs = self.w2_gs.to(device=device) def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": @@ -297,20 +295,12 @@ class WeightTensors: e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - - w1_scale, w2_scale = (None, None) - if self.is_quantized(): - assert self.w1_scale is not None - assert self.w2_scale is not None - w1_scale = self.w1_scale[s:e, :, :] - w2_scale = self.w2_scale[s:e, :, :] - - w1_gs = self.w1_gs - w2_gs = self.w2_gs - if w1_gs is not None: - assert w2_gs is not None - w1_gs = w1_gs[s:e] - w2_gs = w2_gs[s:e] + w1_scale = self.w1_scale[ + s:e, :, :] if self.w1_scale is not None else None + w2_scale = self.w2_scale[ + s:e, :, :] if self.w2_scale is not None else None + w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None + w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) @@ -323,7 +313,8 @@ class WeightTensors: in_dtype=config.dtype, quant_dtype=config.quant_dtype, block_shape=config.quant_block_shape, - per_act_token_quant=config.is_per_out_ch_quant, + per_out_ch_quant=config. 
+ is_per_act_token_quant, # or config.is_per_out_ch_quant ) return WeightTensors(w1=w1, w2=w2, @@ -342,8 +333,6 @@ class RankTensors: topk_ids: torch.Tensor expert_map: Optional[torch.Tensor] - quant_config: Optional[FusedMoEQuantConfig] - def describe(self): s = "" s += "== Rank Tensors: \n" @@ -426,7 +415,6 @@ class RankTensors: topk_weights=topk_weights, topk_ids=topk_ids, expert_map=expert_map, - quant_config=config.quant_config, ) @@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors, and config.supports_apply_weight_on_input()) +def _make_gscale(num_experts: int) -> torch.Tensor: + return torch.ones((num_experts, ), + device=torch.cuda.current_device(), + dtype=torch.float32) + + def make_modular_kernel( config: Config, vllm_config: VllmConfig, - weights: WeightTensors, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEModularKernel: def next_power_of_2(x): @@ -548,20 +542,20 @@ def make_modular_kernel( num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, - quant_config=config.quant_config, max_num_tokens=next_power_of_2(config.M), ) # make modular kernel prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, - config.all2all_backend(), moe) + config.all2all_backend(), moe, + quant_config) fused_experts = make_fused_experts( config.fused_experts_type, moe, + quant_config, prepare_finalize.num_dispatchers(), - weights.w1_gs, - weights.w2_gs, + config.N, ) modular_kernel = mk.FusedMoEModularKernel( @@ -583,12 +577,38 @@ def run_modular_kernel( # weights for rank rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) - mk = make_modular_kernel(config, vllm_config, weights) + if config.quant_dtype == "nvfp4": + gscale = _make_gscale(config.num_local_experts) + else: + gscale = None + + quant_config = FusedMoEQuantConfig.make( + config.quant_dtype, + w1_scale=rank_weights.w1_scale, + w2_scale=rank_weights.w2_scale, + 
a1_scale=rank_tensors.hidden_states_scale, + g1_alphas=(1 / rank_weights.w1_gs) + if rank_weights.w1_gs is not None else None, + g2_alphas=(1 / rank_weights.w2_gs) + if rank_weights.w2_gs is not None else None, + a1_gscale=gscale, + a2_gscale=gscale, + block_shape=config.quant_block_shape, + per_act_token_quant=config.is_per_act_token_quant, + per_out_ch_quant=config.is_per_out_ch_quant, + ) + + mk = make_modular_kernel(config, vllm_config, quant_config) + + # impls might update the tensor in place + hidden_states = rank_tensors.hidden_states.clone() + + topk_ids = rank_tensors.topk_ids.to( + mk.prepare_finalize.topk_indices_dtype()) mk_kwargs = { "hidden_states": - rank_tensors.hidden_states.clone( - ), # impls might update the tensor in place + hidden_states, "w1": rank_weights.w1, "w2": @@ -596,15 +616,9 @@ def run_modular_kernel( "topk_weights": rank_tensors.topk_weights, "topk_ids": - rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), + topk_ids, "expert_map": rank_tensors.expert_map, - "w1_scale": - rank_weights.w1_scale, - "w2_scale": - rank_weights.w2_scale, - "a1_scale": - rank_tensors.hidden_states_scale, "global_num_experts": config.E, "apply_router_weight_on_input": diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index 5dbfdfc153f9f..c1037b60bf383 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -10,7 +10,8 @@ import torch from tqdm import tqdm from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG) from vllm.platforms import current_platform from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, @@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str): 
quant_config_dict = config_dict['quant_config'] del config_dict['quant_config'] if quant_config_dict is None: - quant_config = FusedMoEQuantConfig(None) + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config_dict = asdict(quant_config) config_dict |= quant_config_dict diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index aecffae36ae5e..7947391d03483 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +@dataclass +class TestMoEQuantConfig: + quant_dtype: Union[torch.dtype, str, None] + per_out_ch_quant: bool + per_act_token_quant: bool + block_shape: Optional[list[int]] + + @dataclass class PrepareFinalizeInfo: activation_format: mk.FusedMoEActivationFormat @@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [ torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 ] common_float_and_int_types = common_float_types + [torch.int8] -nv_fp4_types = ["nvfp4"] +nvfp4_types = ["nvfp4"] fp8_types = [torch.float8_e4m3fn] @@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe() register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, backend=None, force_multigpu=True, @@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe() register_experts( FlashInferExperts, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, supports_chunking=True, # Note: this is a hack to get it to run for now @@ -306,39 +314,39 @@ if cutlass_fp4_supported(): register_experts( CutlassExpertsFp4, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, supports_chunking=True, supports_expert_map=False, ) -MK_QUANT_CONFIGS = [ 
+MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [ None, # per-channel / per-column weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None), # per-channel / per-column weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None), # per-tensor weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), # per-tensor weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None), # block-quantized weights and 128 block per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=[128, 128]), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128]), # TODO (varun) : Should we test the following combinations ? 
# block-quantized weights and per-token activations # block-quantized weights and per-tensor activations @@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): MK_QUANT_CONFIGS += [ - FusedMoEQuantConfig(quant_dtype="nvfp4", - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig(quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), ] -def _make_gscale(num_experts: int) -> torch.Tensor: - return torch.ones((num_experts, ), - device=torch.cuda.current_device(), - dtype=torch.float32) - - def make_prepare_finalize( prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, backend: Optional[str], moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize( + moe, quant_config) assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: return FlashInferCutlassMoEPrepareAndFinalize( - use_dp=moe.moe_parallel_config.dp_size > 1, - a1_gscale=_make_gscale(moe.num_local_experts), - ) + use_dp=moe.moe_parallel_config.dp_size > 1) else: return MoEPrepareAndFinalizeNoEP() @@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: return t[s:e] +def make_cutlass_strides( + e: int, + n: int, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + return ab_strides1, ab_strides2, c_strides1, c_strides2 + + def 
make_fused_experts( fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, num_dispatchers: int, - w1_gs: Optional[torch.Tensor], - w2_gs: Optional[torch.Tensor], + N: int, ) -> mk.FusedMoEPermuteExpertsUnpermute: - use_fp8 = moe.quant_dtype == torch.float8_e4m3fn batch_kwargs = { "max_num_tokens": moe.max_num_tokens, "num_dispatchers": num_dispatchers, } quant_kwargs = { - "use_fp8_w8a8": use_fp8, - "use_int8_w8a8": False, - "use_int8_w8a16": False, - "use_int4_w4a16": False, - "block_shape": moe.block_shape, - "per_act_token_quant": moe.per_act_token_quant, + "quant_config": quant_config, } deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000) + if fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | { - "block_shape": moe.block_shape, - "per_act_token_quant": moe.per_act_token_quant, - } + kwargs = batch_kwargs | quant_kwargs print(f"Making BatchedDeepGemmExperts {kwargs} ...") experts = BatchedDeepGemmExperts(**kwargs) elif fused_experts_type == BatchedTritonExperts: @@ -422,8 +429,8 @@ def make_fused_experts( print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts () ...") - experts = DeepGemmExperts() + print("Making DeepGemmExperts {quant_config} ...") + experts = DeepGemmExperts(quant_config) elif fused_experts_type == TritonExperts: kwargs = quant_kwargs print(f"Making TritonExperts {kwargs} ...") @@ -437,62 +444,50 @@ def make_fused_experts( print(f"Making NaiveBatchedExperts {kwargs} ...") experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == CutlassExpertsFp8: + strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - 
"block_shape": moe.block_shape, - } + "ab_strides1": strides[0], + "ab_strides2": strides[1], + "c_strides1": strides[2], + "c_strides2": strides[3], + } | quant_kwargs print(f"Making CutlassExpertsFp8 {kwargs} ...") experts = CutlassExpertsFp8(**kwargs) elif fused_experts_type == CutlassBatchedExpertsFp8: + strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { "max_experts_per_worker": moe.num_local_experts, "num_dispatchers": num_dispatchers, "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, - } + "ab_strides1": strides[0], + "ab_strides2": strides[1], + "c_strides1": strides[2], + "c_strides2": strides[3], + } | quant_kwargs print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") experts = CutlassBatchedExpertsFp8(**kwargs) elif fused_experts_type == CutlassExpertsFp4: - assert w1_gs is not None and w2_gs is not None - num_experts = moe.num_local_experts - rank = moe.moe_parallel_config.dp_rank kwargs = { - "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), - "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), - "a1_gscale": _make_gscale(num_experts), - "a2_gscale": _make_gscale(num_experts), - "max_experts_per_worker": num_experts, - "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, + "max_experts_per_worker": moe.num_local_experts, "num_dispatchers": num_dispatchers, - } + "out_dtype": moe.in_dtype, + } | quant_kwargs print(f"Making CutlassExpertsFp4 {kwargs} ...") experts = CutlassExpertsFp4(**kwargs) elif fused_experts_type == FlashInferExperts: - assert w1_gs is not None and w2_gs is not None - num_experts = moe.num_local_experts - rank = moe.moe_parallel_config.dp_rank kwargs = { - "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), - "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), - "a1_gscale": 
_make_gscale(num_experts), - "a2_gscale": _make_gscale(num_experts), "out_dtype": moe.in_dtype, - "quant_dtype": "nvfp4", "ep_rank": moe.ep_rank, "ep_size": moe.ep_size, "tp_rank": moe.tp_rank, "tp_size": moe.tp_size, - } + } | quant_kwargs print(f"Making FlashInferExperts {kwargs} ...") experts = FlashInferExperts(**kwargs) else: raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80) + return experts diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 018d4c224f75e..afec97e8cffd0 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -6,6 +6,8 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, rank=0, ) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + per_act_token_quant=False, + block_shape=BLOCK_SIZE, + ) + # triton (reference) triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=True, - per_act_token_quant=False, - block_shape=BLOCK_SIZE, + quant_config=quant_config, ) mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts) @@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - w1_scale=w1_s, - w2_scale=w2_s, global_num_experts=E, ) @@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, deepgemm_experts = 
BatchedDeepGemmExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - block_shape=BLOCK_SIZE, - per_act_token_quant=False, + quant_config=quant_config, ) mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts) @@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - w1_scale=w1_s, - w2_scale=w2_s, global_num_experts=E, ) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 00b2d780e66f5..7e79828937c77 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, in_dtype=act_dtype, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) out_shape = (num_experts, max_tokens_per_expert, N) @@ -250,7 +250,7 @@ def test_fused_moe_batched_experts( block_shape=block_shape, in_dtype=act_dtype, quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) if input_scales and quant_dtype is not None: diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index ecc57acc67963..da383e18c3721 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -4,7 +4,7 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_test_quant_config, make_test_weights from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, native_w8a8_block_matmul) from vllm.config import VllmConfig, set_current_vllm_config @@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = 
make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + w1, w2, quant_config = make_test_quant_config( + E, + N, + K, + dtype, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size, + ) - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=block_size) + m_fused_moe = modular_triton_fused_moe(quant_config) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) @@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a, w1, w2, - w1_s, - w2_s, + quant_config.w1_scale, + quant_config.w2_scale, topk_weights, topk_ids, block_size, ) - out = fused_experts( - a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) + out = fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config) - m_out = m_fused_moe( - a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - ) + m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids) - # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0] - tol = 0.035 if M < 40000 else 0.039 + # 0.039 only needed for M >= 8192 + tol = 0.035 if M < 8192 else 0.039 torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol) torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol) @@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + E, + N, + K, + dtype, + torch.float8_e4m3fn, + 
per_out_ch_quant=False, + block_shape=block_size, + ) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 5e4a93963f8e8..041a13ca5585a 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -4,12 +4,12 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_test_quant_config from tests.kernels.quant_utils import (native_per_token_group_quant_int8, native_w8a8_block_matmul) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): @@ -50,7 +50,7 @@ MNK_FACTORS = [ (2048, 128, 128), (2048, 1024, 7168), (2048, 4096, 512), - (2048, 4096, 7168), + (2048, 4096, 4096), ] E = [8, 24] @@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.int8, - per_act_token_quant=False, - block_shape=block_size) + w1, w2, quant_config = make_test_quant_config( + E, + N, + K, + dtype, + quant_dtype=torch.int8, + per_act_token_quant=False, + block_shape=block_size, + ) # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, + out = fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config) + ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale, + quant_config.w2_scale, score, topk, block_size) # Check results diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index c84f66383b902..ca6be767dab39 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import dataclasses from math import prod from typing import Optional @@ -9,6 +10,8 @@ import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp8, run_cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, @@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, def slice_experts(): slice_params = [ "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", - "c_strides2", "w1_scale", "w2_scale" + "c_strides2" ] full_tensors = { k: v @@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, if k in slice_params and k in cutlass_moe_kwargs } + quant_config = cutlass_moe_kwargs["quant_config"] + for i in range(0, num_experts, num_local_experts): s, e = i, i + num_local_experts @@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, 
for k, t in full_tensors.items(): cutlass_moe_kwargs[k] = t[s:e] + new_quant_config = copy.deepcopy(quant_config) + new_quant_config._w1.scale = quant_config.w1_scale[s:e] + new_quant_config._w2.scale = quant_config.w2_scale[s:e] + + cutlass_moe_kwargs["quant_config"] = new_quant_config + yield cutlass_moe_kwargs out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) @@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, + per_out_ch: bool, num_local_experts: Optional[int] = None) -> torch.Tensor: assert not any([ t is None for t in [ @@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ] ]) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=moe_tensors.w1_scale, + w2_scale=moe_tensors.w2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + # Set to moe_tensors.a_scale iff static scales + per tensor. + # This is not currently being tested. + a1_scale=None, + ) + kwargs = { 'a': moe_tensors.a, 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] 'topk_weights': topk_weights, 'topk_ids': topk_ids, - 'w1_scale': moe_tensors.w1_scale, - 'w2_scale': moe_tensors.w2_scale, 'ab_strides1': moe_tensors.ab_strides1, 'ab_strides2': moe_tensors.ab_strides2, 'c_strides1': moe_tensors.c_strides1, 'c_strides2': moe_tensors.c_strides2, - 'per_act_token': per_act_token, - 'a1_scale': None #moe_tensors.a_scale + 'quant_config': quant_config, } num_experts = moe_tensors.w1.size(0) @@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph( # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. 
- triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + triton_output = fused_experts(mt.a_d, + mt.w1_d, + mt.w2_d, + topk_weights, + topk_ids, + quant_config=quant_config) if ep_size is not None: assert e % ep_size == 0, "Cannot distribute experts evenly" number_local_experts = e // ep_size else: number_local_experts = None + cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, - number_local_experts) + per_out_ch, number_local_experts) # Note 5.5 only needed for larger problem sizes, 5 works ok for # the rest. @@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph( # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. - triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + triton_output = fused_experts(mt.a_d, + mt.w1_d, + mt.w2_d, + topk_weights, + topk_ids, + quant_config=quant_config) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): cutlass_output = run_8_bit(mt, topk_weights, topk_ids, - per_act_token) + per_act_token, per_out_ch) torch.cuda.synchronize() graph.replay() diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 6558cab6a9eff..ced5457d4f53b 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -71,9 
+73,12 @@ def make_block_quant_fp8_weights( Return weights w1q, w2q, w1_scale, w2_scale """ (_, w1q, w1_scale, _), (_, w2q, w2_scale, - _) = make_test_weights(e, n, k, torch.bfloat16, + _) = make_test_weights(e, + n, + k, + torch.bfloat16, torch.float8_e4m3fn, - block_size) + block_shape=block_size) return w1q, w2q, w1_scale, w2_scale @@ -130,10 +135,11 @@ class TestTensors: config=config) -def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - max_tokens_per_rank: int, dp_size: int, - hidden_size: int, q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: +def make_ll_modular_kernel( + pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int, + dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: assert test_config.low_latency assert test_config.use_fp8_dispatch is not None @@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, fused_experts = BatchedDeepGemmExperts( max_num_tokens=max_tokens_per_rank, num_dispatchers=pgi.world_size // dp_size, - block_shape=test_config.block_size, - per_act_token_quant=test_config.per_act_token_quant) + quant_config=quant_config, + ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: +def make_ht_modular_kernel( + pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + num_local_experts: int, q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: assert not test_config.low_latency assert test_config.use_fp8_dispatch is None @@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, q_dtype=q_dtype, 
block_shape=test_config.block_size) - fused_experts = DeepGemmExperts() + fused_experts = DeepGemmExperts(quant_config) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, - test_tensors: TestTensors) -> FusedMoEModularKernel: +def make_modular_kernel( + pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + num_local_experts: int, test_tensors: TestTensors, + quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: q_dtype = torch.float8_e4m3fn test_config = test_tensors.config @@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, dp_size=dp_size, hidden_size=hidden_size, q_dtype=q_dtype, - test_config=test_config) + test_config=test_config, + quant_config=quant_config) else: - mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts, - q_dtype, test_config) + mk = make_ht_modular_kernel(pg, + pgi, + dp_size, + num_local_experts, + q_dtype, + test_config, + quant_config=quant_config) return mk @@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + # Low-Latency kernels can't dispatch scales. + a1_scale=(None if test_config.low_latency else + test_tensors.rank_token_scales), + block_shape=test_config.block_size, + ) + # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( pg=pg, pgi=pgi, dp_size=dp_size, num_local_experts=num_local_experts, - test_tensors=test_tensors) - - # Low-Latency kernels can't dispatch scales. 
- a1_scale = (None - if test_config.low_latency else test_tensors.rank_token_scales) + test_tensors=test_tensors, + quant_config=quant_config) out = mk.forward(hidden_states=test_tensors.rank_tokens, w1=w1, @@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, activation="silu", global_num_experts=num_experts, expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=a1_scale, - a2_scale=None, apply_router_weight_on_input=False) return out @@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, a1_scale: torch.Tensor, block_shape: list[int]): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + block_shape=block_shape, + ) + return fused_experts( hidden_states=a, w1=w1, @@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - block_shape=block_shape, + quant_config=quant_config, # Make sure this is set to False so we # don't end up comparing the same implementation. 
allow_deep_gemm=False) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 6a53af68cd53a..54d3a62b03fcc 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -15,6 +15,7 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -129,11 +130,9 @@ def make_modular_kernel( num_local_experts: int, q_dtype: Optional[torch.dtype], use_fp8_dispatch: bool, - per_act_token_quant: bool, + quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: - is_quantized = q_dtype is not None - ht_args: Optional[DeepEPHTArgs] = None ll_args: Optional[DeepEPLLArgs] = None @@ -159,24 +158,14 @@ def make_modular_kernel( num_dispatchers = pgi.world_size // dp_size if low_latency_mode: - assert not per_act_token_quant, "not supported in ll mode" + assert not quant_config.per_act_token_quant, "not supported in ll mode" fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, num_dispatchers=num_dispatchers, - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=False, + quant_config=quant_config, ) else: - fused_experts = TritonExperts( - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=per_act_token_quant, - ) + fused_experts = TritonExperts(quant_config=quant_config) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) @@ -217,11 +206,6 @@ def deep_ep_moe_impl( if is_quantized: q_dtype = torch.float8_e4m3fn - # Make modular kernel 
- mk: FusedMoEModularKernel = make_modular_kernel( - pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, - num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant) - out_hidden_states = torch.empty_like(test_tensors.rank_tokens) total_num_tokens = test_tensors.rank_tokens.size(0) @@ -236,6 +220,19 @@ def deep_ep_moe_impl( rank_token_scales_chunk = rank_token_scales_chunk[ chunk_start:chunk_end] + quant_config = FusedMoEQuantConfig.make( + q_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token_quant, + a1_scale=rank_token_scales_chunk, + ) + + # Make modular kernel + mk: FusedMoEModularKernel = make_modular_kernel( + pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, + num_local_experts, q_dtype, use_fp8_dispatch, quant_config) + out = mk.forward(hidden_states=rank_tokens_chunk, w1=w1, w2=w2, @@ -245,12 +242,6 @@ def deep_ep_moe_impl( activation="silu", global_num_experts=num_experts, expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=rank_token_scales_chunk, - a2_scale=None, apply_router_weight_on_input=False) if not skip_result_store: @@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("m,n,k", MNKs) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @requires_deep_ep def test_deep_ep_moe( dtype: torch.dtype, - mnk: tuple[int, int, int], + m: int, + n: int, + k: int, num_experts: int, topk: int, world_dp_size: tuple[int, int], @@ -424,7 +417,6 @@ def test_deep_ep_moe( ): low_latency_mode = False use_fp8_dispatch = False - m, n, k = mnk current_platform.seed_everything(7) world_size, dp_size = world_dp_size @@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False] 
@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("m,n,k", MNKs) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @multi_gpu_test(num_gpus=2) @requires_deep_ep -def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], - num_experts: int, topk: int, - world_dp_size: tuple[int, int], - use_fp8_dispatch: bool): - +def test_low_latency_deep_ep_moe( + dtype: torch.dtype, + m: int, + n: int, + k: int, + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], + use_fp8_dispatch: bool, +): low_latency_mode = True - m, n, k = mnk if (low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 4472f34a6291a..d575b6d4ca62c 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -11,6 +11,8 @@ import math import pytest import torch +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) # vLLM fused-expert reference (Triton fallback + DeepGEMM option) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + ) + # triton reference out_triton = fused_experts( hidden_states=tokens_bf16, @@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - 
w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - block_shape=block_size, + quant_config=quant_config, allow_deep_gemm=False, ) @@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - block_shape=block_size, + quant_config=quant_config, allow_deep_gemm=True, ) diff = calc_diff(out_deepgemm, out_triton) assert diff < 0.001, f"Diff exceeded 1%: {diff}" -# Note: W1 has shape (E, 2N, K), so N = 512 -# can trigger the deepgemm path. +# Note: N <= 512 will disable the deepgemm path due to performance issues. MNKs = [ (1024, 768, 128), (1024, 768, 512), @@ -144,15 +144,15 @@ TOPKS = [2, 6] NUM_EXPERTS = [32] -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize(("m", "n", "k"), MNKs) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") -def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): +def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_DEEP_GEMM", "1") + with monkeypatch.context() as mp: + mp.setenv("VLLM_USE_DEEP_GEMM", "1") _fused_moe_mod = importlib.import_module( "vllm.model_executor.layers.fused_moe.fused_moe") @@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8) - m, n, k = mnk - if topk > num_experts: pytest.skip(f"topk={topk} > num_experts={num_experts}") diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 52a3d2ca3b422..5564db3cda0e3 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -6,6 +6,8 @@ import pytest import torch from vllm.config import ParallelConfig, 
VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( custom_routing_function=Llama4MoE.custom_routing_function, scoring_func="softmax") + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + per_act_token_quant=False, + ) + output = fused_experts( td.hidden_states, td.w13_quantized, @@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( topk_ids=topk_ids, inplace=False, activation="silu", - use_fp8_w8a8=True, - per_channel_quant=False, global_num_experts=e, expert_map=None, - w1_scale=td.w13_weight_scale, - w2_scale=td.w2_weight_scale, - a1_scale=td.a1_scale, - a2_scale=td.a2_scale, apply_router_weight_on_input=True, + quant_config=quant_config, ) flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( @@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( custom_routing_function=Llama4MoE.custom_routing_function, scoring_func="softmax") + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + per_act_token_quant=False, + ) + output = fused_experts( td.hidden_states, td.w13_quantized, @@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk_ids=topk_ids, inplace=False, activation="silu", - use_fp8_w8a8=True, - per_channel_quant=False, global_num_experts=e, expert_map=None, - w1_scale=td.w13_weight_scale, - w2_scale=td.w2_weight_scale, - a1_scale=td.a1_scale, - a2_scale=td.a2_scale, apply_router_weight_on_input=True, + quant_config=quant_config, ) td.layer.dp_size = 1 diff 
--git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1c14df2b914aa..8bf096b798cb8 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -3,7 +3,7 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_test_quant_config from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype) @@ -41,7 +41,6 @@ MNK_FACTORS = [ @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", [40, 64, 256]) -#@pytest.mark.parametrize("e", [128, 256]) @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() @@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, quant_blocksize = 16 - (_, w1_q, w1_blockscale, - w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( - e, - n, - k, - in_dtype=dtype, - quant_dtype="nvfp4", - block_shape=None, # use quant_blocksize? 
- per_act_token_quant=False, - ) + w1_q, w2_q, quant_config = make_test_quant_config( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, + per_act_token_quant=False, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, @@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, topk, renormalize=False) - a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) - assert w1_gs is not None - assert w2_gs is not None - assert w1_blockscale is not None - assert w2_blockscale is not None - flashinfer_experts = FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - FlashInferExperts( - a1_gscale=a1_gs, - g1_alphas=(1 / w1_gs), - a2_gscale=a2_gs, - g2_alphas=(1 / w2_gs), - out_dtype=dtype, - quant_dtype="nvfp4", - )) + FlashInferExperts(out_dtype=dtype, quant_config=quant_config), + ) flashinfer_output = flashinfer_experts( hidden_states=a, w1=w1_q, - w1_scale=w1_blockscale, w2=w2_q, - w2_scale=w2_blockscale, - a1_scale=a1_gs, - a2_scale=a2_gs, topk_weights=topk_weights, topk_ids=topk_ids, ) @@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]), + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + 
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]), + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 54f2351bf6d9b..024993c7677dd 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.testing import assert_close +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): pc2, ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) + quant_config = FusedMoEQuantConfig.make( + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2, + ) + out_triton_monolithic = triton_kernel_moe_forward( hidden_states=x_tri, w1=w1_tri, @@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): gating_output=exp_data_tri, topk=topk, renormalize=True, - w1_bias=w1_bias_tri, - w2_bias=w2_bias_tri, - w1_precision=pc1, - w2_precision=pc2, + quant_config=quant_config, ) out_triton_monolithic = out_triton_monolithic[..., :K] @@ -336,6 +341,13 @@ def batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + w1_precision=w1_precision, + w2_precision=w2_precision, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize( max_num_tokens, @@ -344,19 +356,12 @@ def batched_moe( rank=0, ), BatchedOAITritonExperts( - None, max_num_tokens=max_num_tokens, 
num_dispatchers=1, - w1_precision=w1_precision, - w2_precision=w2_precision, + quant_config=quant_config, ), ) - extra_expert_args = { - "w1_bias": w1_bias, - "w2_bias": w2_bias, - } - topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) return fused_experts( @@ -365,7 +370,6 @@ def batched_moe( w2, topk_weight, topk_ids, - extra_expert_args=extra_expert_args, ) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6112183be5475..19c4301bd23d5 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -12,7 +12,6 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, current_platform, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, run_modular_kernel) from .modular_kernel_tools.mk_objects import ( MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig, + expert_info) from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, parallel_launch_with_config) @@ -55,7 +55,7 @@ def rank_worker( pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group, - config: Config, + base_config: Config, weights: WeightTensors, verbose: bool, ): @@ -63,42 +63,44 @@ def rank_worker( # sanity check from vllm import envs - if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + if base_config.fused_moe_chunk_size is not None: + assert ( + 
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) # get weights to this device weights.to_current_device() - Ms = config.Ms + Ms = base_config.Ms assert isinstance(Ms, list) - TOPKs = config.topks + TOPKs = base_config.topks assert isinstance(TOPKs, list) exceptions = [] count = 0 for m, topk in product(Ms, TOPKs): + # override m and topk + config = copy.deepcopy(base_config) + config.Ms = m + config.topks = topk + try: print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") count = count + 1 - # override m and topk - cfgx = copy.deepcopy(config) - cfgx.Ms = m - cfgx.topks = topk # inputs for rank - rank_tensors = RankTensors.make(cfgx, pgi) + rank_tensors = RankTensors.make(config, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors) with set_current_vllm_config(vllm_config): - ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + ref_out = reference_moe_impl(config, weights, rank_tensors) if config.quant_dtype == "nvfp4": - atol = 1e-1 - rtol = 1e-1 + atol = 1e-1 if config.K < 4096 else 2e-1 + rtol = 1e-1 if config.K < 4096 else 2e-1 else: atol = 3e-2 rtol = 3e-2 @@ -132,7 +134,7 @@ Ms = [32, 64] # hidden sizes, making this too large will cause fp4 tests to fail. # Also needs to be a multiple of 1024 for deep_gemm. 
Ks = [2048] -Ns = [2048] +Ns = [1024] TOPKs = [4, 1] Es = [32] DTYPEs = [torch.bfloat16] @@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool: @meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: Optional[TestMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): @@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu( @pytest.mark.parametrize("world_size", [1]) def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: Optional[TestMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 850c486b95240..00835bec9a15c 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -15,11 +15,14 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from 
vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( @@ -187,14 +190,9 @@ def test_fused_moe( # # Setup test functions # + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=None) + m_fused_moe_fn = modular_triton_fused_moe(quant_config) def m_fused_moe( a: torch.Tensor, @@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None + if weight_bits == 4: + quant_config_builder = int4_w4a16_moe_quant_config + else: + assert weight_bits == 8 + quant_config_builder = int8_w8a16_moe_quant_config + + quant_config = quant_config_builder(w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + with set_current_vllm_config(vllm_config): triton_output = fused_moe(a, w1_qweight, @@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, score, topk, renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, global_num_experts=e, expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) + quant_config=quant_config) torch_output = torch_moe(a, w1_ref, w2_ref, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 30388ef9375d4..a48bfeb10b2e6 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import 
nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform @@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, in_dtype=dtype, quant_dtype="nvfp4", block_shape=None, # use quant_blocksize? - per_act_token_quant=False, + per_out_ch_quant=False, ) score = torch.randn((m, e), device="cuda", dtype=dtype) @@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, assert w1_blockscale is not None assert w2_blockscale is not None + quant_config = nvfp4_moe_quant_config( + g1_alphas=(1 / w1_gs), + g2_alphas=(1 / w2_gs), + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + ) + cutlass_output = cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, w1_fp4=w1_q, - w1_blockscale=w1_blockscale, - g1_alphas=(1 / w1_gs), - a2_gscale=a2_gs, w2_fp4=w2_q, - w2_blockscale=w2_blockscale, - g2_alphas=(1 / w2_gs), topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=quant_config, m=m, n=n, k=k, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 9e78f4d6e4da0..59126cef6adbb 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,6 +9,8 @@ import torch from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassBatchedExpertsFp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -143,10 +145,16 @@ def pplx_cutlass_moe( device="cuda", dtype=torch.int64) - experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, - out_dtype, per_act_token, 
per_out_ch, - ab_strides1, ab_strides2, c_strides1, - c_strides2) + experts = CutlassBatchedExpertsFp8( + num_local_experts, num_dispatchers, out_dtype, ab_strides1, + ab_strides2, c_strides1, c_strides2, + fp8_w8a8_moe_quant_config( + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + w1_scale=chunk_by_rank(w1_scale, rank, world_size), + w2_scale=chunk_by_rank(w2_scale, rank, world_size), + a1_scale=chunk_by_rank(a1_scale, rank, world_size) + if per_act_token else a1_scale[rank])) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -167,10 +175,7 @@ def pplx_cutlass_moe( chunk_topk_ids, global_num_experts=num_experts, expert_map=None, #TODO - w1_scale=chunk_by_rank(w1_scale, rank, world_size), - w2_scale=chunk_by_rank(w2_scale, rank, world_size), - a1_scale=chunk_by_rank(a1_scale, rank, world_size) - if per_act_token else a1_scale[rank]) + ) torch.cuda.synchronize() diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 394f521140859..4ca4a1e79c57c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [ ] PPLX_COMBOS = [ - # TODO: figure out why this fails, seems to be test problem + # TODO(bnell): figure out why this fails, seems to be test problem #(1, 128, 128), (2, 128, 512), (3, 1024, 2048), @@ -360,18 +360,18 @@ def pplx_prepare_finalize( b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( a_chunk, - a1_scale, - a2_scale, chunk_topk_weight, chunk_topk_ids, num_experts, None, False, - FusedMoEQuantConfig( + FusedMoEQuantConfig.make( quant_dtype, - per_act_token_quant, - False, - block_shape, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=block_shape, + a1_scale=a1_scale, + a2_scale=a2_scale, ), ) @@ -540,20 +540,6 @@ def pplx_moe( topk_ids = topk_ids.to(dtype=torch.uint32) - experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - 
num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - ) - - fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - shared_experts, - ) - # Note: workers with the same dp_rank must use the exact same inputs. a_chunk = chunk_by_rank(a, rank, world_size) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size) @@ -567,6 +553,28 @@ def pplx_moe( a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size) a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, + ) + + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=quant_config, + ) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts, + ) + # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. 
@@ -585,10 +593,6 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, global_num_experts=num_experts) if use_cudagraphs: @@ -605,10 +609,6 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, global_num_experts=num_experts) torch.cuda.synchronize() @@ -820,7 +820,7 @@ def test_pplx_moe_slow( k, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, @@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, k, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) args["w1"] = w1 args["w2"] = w2 diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index dfd0f35c8da3d..1c31464b30e7f 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,10 +7,12 @@ import itertools import pytest import torch +from tests.kernels.moe.utils import fused_moe from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.platforms import current_platform if current_platform.get_device_capability() < (9, 0): @@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): score, topk, renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - 
block_shape=None, # Not using block quantization + quant_config=fp8_w8a8_moe_quant_config( + per_act_token_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ), ) # Check results diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 4b58a28eed125..7a0feb6a20795 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -34,18 +35,22 @@ def triton_moe( per_act_token_quant=False, block_shape: Optional[list[int]] = None, ) -> torch.Tensor: + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + return fused_experts(a, w1, w2, topk_weight, topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_channel_quant=per_act_token_quant, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + quant_config=quant_config) def batched_moe( @@ -64,6 +69,16 @@ def batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + 
a2_scale=a2_scale, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, num_dispatchers=1, @@ -72,21 +87,11 @@ def batched_moe( BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + quant_config=quant_config, ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts(a, w1, w2, topk_weight, topk_ids) def naive_batched_moe( @@ -105,6 +110,16 @@ def naive_batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, num_dispatchers=1, @@ -113,21 +128,11 @@ def naive_batched_moe( NaiveBatchedExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + quant_config=quant_config, ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts(a, w1, w2, topk_weight, topk_ids) def chunk_scales(scales: Optional[torch.Tensor], start: int, @@ -216,7 +221,7 @@ def make_test_weight( in_dtype: torch.dtype = torch.bfloat16, quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 @@ -228,7 
+233,7 @@ def make_test_weight( w_gs_l = [None] * e for idx in range(e): w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( - w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) + w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape) w = torch.stack(w_l) w_s = torch.stack(w_s_l) @@ -258,16 +263,16 @@ def make_test_weights( in_dtype: torch.dtype = torch.bfloat16, quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]: return ( make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + per_out_ch_quant), make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + per_out_ch_quant), ) @@ -285,6 +290,76 @@ def per_token_cast_to_fp8( return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) +def make_test_quant_config( + e: int, + n: int, + k: int, + in_dtype: torch.dtype, + quant_dtype: Union[torch.dtype, str, None] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]: + (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype, + quant_dtype, + per_out_ch_quant=per_act_token_quant, + block_shape=block_shape, + ) + + # Hacky/trivial scales for nvfp4. 
+ a1_gscale: Optional[torch.Tensor] = None + a2_gscale: Optional[torch.Tensor] = None + if quant_dtype == "nvfp4": + a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32) + a1_scale = a1_gscale + a2_scale = a2_gscale + else: + a1_scale = None + a2_scale = None + + return w1, w2, FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_s, + w2_scale=w2_s, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + a1_scale=a1_scale, + a2_scale=a2_scale, + # TODO: make sure this is handled properly + g1_alphas=(1 / w1_gs) if w1_gs is not None else None, + g2_alphas=(1 / w2_gs) if w2_gs is not None else None, + ) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + renormalize: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=quant_config) + + # CustomOp? 
class BaselineMM(torch.nn.Module): diff --git a/tests/kernels/quantization/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py index dc5fecbf4ccc8..f2271e6be5420 100644 --- a/tests/kernels/quantization/test_int8_kernel.py +++ b/tests/kernels/quantization/test_int8_kernel.py @@ -8,7 +8,8 @@ import pytest import torch from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_quant_int8) from vllm.platforms import current_platform @@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): return C.reshape(origin_C_shape).to(output_dtype) -def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, + topk_ids): """This function performs fused moe with per-column int8 quantization using native torch.""" @@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) # Calculate routing - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) # Process each expert @@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(score, topk) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( + ref_out = 
torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, + topk_weights, topk_ids) + + quant_config = FusedMoEQuantConfig.make( + torch.int8, + per_act_token_quant=True, + block_shape=None, + w1_scale=w1_s, + w2_scale=w2_s, + ) + + out = fused_experts( a, w1, w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, # Using int8-w8a8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization + topk_weights, + topk_ids, + quant_config=quant_config, ) # Check results diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3007643d7a288..6730f051e3d71 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -36,6 +37,7 @@ __all__ = [ "FusedMoEPermuteExpertsUnpermute", "FusedMoEActivationFormat", "FusedMoEPrepareAndFinalize", + "activation_without_mul", "override_config", "get_config", ] @@ -43,7 +45,6 @@ __all__ = [ if HAS_TRITON: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa - import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 @@ -56,13 +57,12 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, fused_experts, 
fused_moe, fused_topk, - get_config_file_name, grouped_topk) + TritonExperts, fused_experts, fused_topk, get_config_file_name, + grouped_topk) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) __all__ += [ - "fused_moe", "fused_topk", "fused_experts", "get_config_file_name", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 0ab6355f41565..e9dfb22bea27b 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -8,6 +8,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache @@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda( class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - # The Deep Gemm kernels only support block size of 128 - DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] - def __init__(self, - max_num_tokens: int, - num_dispatchers: int, - block_shape: list[int], - per_act_token_quant=False): + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + ): """ max_num_tokens: Maximum number of tokens from a DP Rank num_dispatchers: The number of DP dispatchers. - block_shape: Block quantization block shape. - per_act_token_quant: Per activation token quantization flag. 
+ quant_config: Quantization configuration """ - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE + super().__init__(quant_config) + assert self.block_shape == deep_gemm_block_shape() self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # for the M expectation of each batch, correctly setting this value # may lead to better performance. 
expected_m = max_num_tokens - fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), + fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale), workspace1, expert_num_tokens, expected_m) a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( workspace1, expert_num_tokens) - fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, - expert_num_tokens, expected_m) + fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale), + output, expert_num_tokens, expected_m) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 89d7412ee2236..8b9070f098898 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, - max_num_tokens: int, - num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, - allow_deep_gemm: bool = False): - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - block_shape=block_shape, - 
per_act_token_quant=per_act_token_quant, - )) + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + allow_deep_gemm: bool = False, + ): + super().__init__(quant_config) self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape, + quant_config=self.quant_config, ) - self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 - and self.block_shape - == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + self.allow_deep_gemm = (allow_deep_gemm + and self.quant_config.use_fp8_w8a8 and + self.block_shape == deep_gemm_block_shape()) self.batched_deep_gemm_experts = BatchedDeepGemmExperts( max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, - block_shape=self.block_shape, # type: ignore[arg-type] + quant_config=self.quant_config, ) if self.allow_deep_gemm else None assert (self.batched_deep_gemm_experts is not None @@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, - activation, global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, - 
workspace2, expert_tokens_meta, + activation, global_num_experts, expert_map, a1q_scale, + workspace13, workspace2, expert_tokens_meta, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 0b501cd87fb5d..742df3dbdc6af 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,103 +1,322 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import torch -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) import vllm.envs as envs from vllm.config import ParallelConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.utils import cdiv +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.utils import cdiv, has_triton_kernels from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +if TYPE_CHECKING and has_triton_kernels: + from triton_kernels.matmul_ogs import PrecisionConfig + logger = init_logger(__name__) -def _get_quant_config_quantization_args( - quant_config: Optional[QuantizationConfig], - prop_name: str, -) -> Optional[QuantizationArgs]: - if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') - and "Linear" in quant_config.target_scheme_map and - "input_activations" in quant_config.target_scheme_map["Linear"]): - return quant_config.target_scheme_map["Linear"].get(prop_name) - else: - return None - - -def get_quant_config_input_quant( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - return 
_get_quant_config_quantization_args(quant_config, - "input_activations") - - -def get_quant_config_weight_quant( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - return _get_quant_config_quantization_args(quant_config, "weights") - - -def get_config_quant_dtype( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, -) -> Union[None, torch.dtype, str]: +def _get_config_dtype_str( + dtype: torch.dtype, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, +) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. See + try_get_optimal_moe_config in fused_moe.py. + """ if use_fp8_w8a8: - return torch.float8_e4m3fn - elif use_int8_w8a8: - return torch.int8 + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" elif use_mxfp4_w4a4: - return "mxfp4" + return "mxfp4_w4a4" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" return None +def _quant_flags_to_group_shape( + quant_dtype: Union[torch.dtype, str, None], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]], +) -> tuple[Optional[GroupShape], Optional[GroupShape]]: + """ + Convert MoE quantization flags into more generic GroupShapes. + """ + a_shape: Optional[GroupShape] + w_shape: Optional[GroupShape] + if block_shape is not None: + assert not per_act_token_quant + assert not per_out_ch_quant + # TODO(bnell): this is not quite right for activations since first + # dim should be 1. 
+ a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + else: + w_shape = None + a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR + + if per_act_token_quant: + a_shape = GroupShape.PER_TOKEN + + if per_out_ch_quant: + w_shape = GroupShape.PER_TOKEN + + return a_shape, w_shape + + +@dataclass +class FusedMoEQuantDesc: + """ + A quantization descriptor for fused MoE ops. This class can describe + either activations or weights. + """ + + # The quantized type of this parameter. None means unquantized or + # already quantized. + # TODO (bnell): use scalar_type instead of Union. + dtype: Union[torch.dtype, str, None] = None + + # A field that describes the quantization group shape, from quant_utils.py. + # * (-1, -1) for per-tensor quantization + # * (1, -1) for per-row quantization + # * (-1, 1) for per-column quantization + # * (128, 128) for 128x128 deepseek style block quantization + # * (1, 128) for deepseek style activation quantization + # (i.e. per-token-per-group) + shape: Optional[GroupShape] = None + + # Quantization scales. + # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc? + scale: Union[torch.Tensor, "PrecisionConfig", None] = None + + # Quantization alphas or gscales, used for nvfp4 types. + # TODO(bnell): put some of these in subclasses + alpha_or_gscale: Optional[torch.Tensor] = None + + # Zero points for int4/int8 types + zp: Optional[torch.Tensor] = None + + # Biases for GPT triton MoE + bias: Optional[torch.Tensor] = None + + +# TODO(bnell): have subclasses for specific moe methods? +# e.g. for specific arguments bias, precision, etc. @dataclass class FusedMoEQuantConfig: - # The post quantization activation type. - # TODO (bnell): use scalar_type instead of Union. 
- quant_dtype: Union[torch.dtype, str, None] = None - per_act_token_quant: bool = False - per_out_ch_quant: bool = False - block_shape: Optional[list[int]] = None + """ + The FusedMoEQuantConfig contains all the quantization parameters for + a single FusedMoEMethodBase operation. It consists of four + FusedMoEQuantDescs, one for each activation and set of weights. - # TODO: add col major flag? - # add detailed quant info for input, intermediates, weights, etc? + Each FusedMoEMethodBase must implement a get_fused_moe_quant_config + method to construct a FusedMoEQuantConfig for use with that class. + + FusedMoEQuant configs are only used for modular kernels, fused_experts + (from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and + triton_kernel_moe_forward. Other MoE methods can ignore the + FusedMoEQuantConfig (for now) and hardcode it to None. + + There are currently some restrictions on what can be expressed: + - Most MoE ops only support similar quantization strategies for + each parameter, e.g. both weights must have the same GroupShape + and both activations must share the same GroupShape. One exception to + this is the cutlass moe which allows per channel quantization on the + outputs. Note: these restrictions are not always rigorously checked. + - Not all fused MoE functions support all the parameters, e.g. zero points, + global scales, alphas and biases are not universally supported. + - Fully general GroupShapes are not allowed. Activations only support + per token, per tensor or K-blocked. + - Weights are not required to have a GroupShape since they have already + been quantized. + + Other notes: + - PrecisionConfigs are specific to GPT OSS Triton. + - As a follow up it would probably make sense to subclass FusedMoEQuantDesc + or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses + so that only the required quantization parameters are used/stored. 
+ """ + + # TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking + _a1: FusedMoEQuantDesc + _a2: FusedMoEQuantDesc + _w1: FusedMoEQuantDesc + _w2: FusedMoEQuantDesc def __post_init__(self): assert (not self.per_act_token_quant or self.block_shape is None), "illegal quantization" + # + # Convenience accessors for various properties. + # + + @property + def quant_dtype(self) -> Union[torch.dtype, str, None]: + return self._a1.dtype + @property def is_quantized(self) -> bool: return self.quant_dtype is not None @property def is_per_act_token(self) -> bool: - return self.per_act_token_quant + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_act_token_quant(self) -> bool: + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_out_ch_quant(self) -> bool: + return self._w1.shape == GroupShape.PER_TOKEN + + @property + def is_per_tensor(self) -> bool: + return self._a1.shape == GroupShape.PER_TENSOR + + @property + def block_shape(self) -> Optional[list[int]]: + if (self._a1.shape is not None + and self._a1.shape != GroupShape.PER_TENSOR + and self._a1.shape != GroupShape.PER_TOKEN): + return [self._a1.shape.row, self._a1.shape.col] + else: + return None @property def is_block_quantized(self) -> bool: return self.block_shape is not None @property - def is_per_tensor(self) -> bool: - return not self.per_act_token_quant and self.block_shape is None + def a1_scale(self) -> Optional[torch.Tensor]: + assert self._a1.scale is None or isinstance(self._a1.scale, + torch.Tensor) + return self._a1.scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self._a1.alpha_or_gscale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + assert self._a2.scale is None or isinstance(self._a2.scale, + torch.Tensor) + return self._a2.scale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self._a2.alpha_or_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + assert 
self._w1.scale is None or isinstance(self._w1.scale, + torch.Tensor) + return self._w1.scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self._w1.zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self._w1.bias + + @property + def w1_precision(self) -> Optional["PrecisionConfig"]: + assert self._w1.scale is None or isinstance(self._w1.scale, + PrecisionConfig) + return self._w1.scale + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self._w1.alpha_or_gscale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + assert self._w2.scale is None or isinstance(self._w2.scale, + torch.Tensor) + return self._w2.scale + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self._w2.zp + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self._w2.bias + + @property + def w2_precision(self) -> Optional["PrecisionConfig"]: + assert self._w2.scale is None or isinstance(self._w2.scale, + PrecisionConfig) + return self._w2.scale + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self._w2.alpha_or_gscale + + @property + def use_fp8_w8a8(self) -> bool: + return self.quant_dtype == torch.float8_e4m3fn + + @property + def use_int8_w8a8(self) -> bool: + return self.quant_dtype == torch.int8 + + @property + def use_int8_w8a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == torch.int8) + + @property + def use_int4_w4a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == "int4") + + @property + def use_mxfp4_w4a4(self) -> bool: + return self.quant_dtype == "mxfp4" + + @property + def use_nvfp4_w4a4(self) -> bool: + return self.quant_dtype == "nvfp4" + + def config_name(self, dtype: torch.dtype) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. See + try_get_optimal_moe_config in fused_moe.py. 
+ """ + return _get_config_dtype_str( + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, + dtype=dtype, + ) def scale_shape( self, max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int]]: + """ + Construct the proper activation scale shape for this + config. + """ if self.is_quantized: if self.is_block_quantized: assert self.block_shape is not None @@ -117,6 +336,10 @@ class FusedMoEQuantConfig: max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int, int]]: + """ + Construct the proper activation batched scale shape for this + config, e.g. (num experts, *scale_shape). + """ if self.is_quantized: scale_shape = self.scale_shape(max_tokens, hidden_dim) assert scale_shape is not None @@ -126,38 +349,218 @@ class FusedMoEQuantConfig: @staticmethod def make( - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + quant_dtype: Union[torch.dtype, str, None] = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: Optional[list[int]] = None, + w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + g1_alphas: Optional[torch.Tensor] = None, + g2_alphas: Optional[torch.Tensor] = None, + a1_gscale: Optional[torch.Tensor] = None, + a2_gscale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, ) -> "FusedMoEQuantConfig": - assert sum([ - int(flag) for flag in [ - use_fp8_w8a8, - use_int8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - use_mxfp4_w4a4, - ] - ]) <= 1, "Quantization flags are mutually exclusive." 
+ """ + General builder function for a FusedMoEQuantConfig. + - quant_dtype: Optional quantization type. None if activations are + unquantized or quantized prior to calling. Note: "nvfp4" and + "mxfp4" are the only valid string values for quant_dtype. + - per_act_token_quant: Activations have per token quantization. + - per_out_ch_quant: Outputs have per channel quantization. (only + for cutlass). + - block_shape: Optional block size for block-wise quantization. + Incompatible with per_act_token and per_out_ch quant. + - w1_scale: Optional scale to be used for w1. + - w2_scale: Optional scale to be used for w2. + - a1_scale: Optional scale to be used for a1. + - a2_scale: Optional scale to be used for a2. + - g1_alphas: Optional global quantization scales for w1 (for nvfp4). + - g2_alphas: Optional global quantization scales for w2 (for nvfp4). + - a1_gscale: Optional global quantization scales for a1 (for nvfp4). + - a2_gscale: Optional global quantization scales for a2 (for nvfp4). + - w1_bias: Optional biases for w1 (GPT OSS Triton). + - w2_bias: Optional biases for w2 (GPT OSS Triton). + - w1_zp: Optional w1 zero points for int4/int8 quantization. + - w2_zp: Optional w2 zero points for int4/int8 quantization. 
+ """ + assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4" + or quant_dtype == "mxfp4") + a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape) + quant_config = FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale), + _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale), + _w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas, + w1_zp, w1_bias), + _w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas, + w2_zp, w2_bias), + ) + assert quant_config.per_act_token_quant == per_act_token_quant + assert quant_config.per_out_ch_quant == per_out_ch_quant + assert quant_config.block_shape == block_shape + return quant_config - quant_dtype = get_config_quant_dtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - ) - return FusedMoEQuantConfig( - quant_dtype, - per_act_token_quant, - per_out_ch_quant, - block_shape, - ) + +def fp8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and fp8 weights. 
+ """ + return FusedMoEQuantConfig.make(torch.float8_e4m3fn, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape) + + +def int8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + per_act_token_quant: bool = False, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for int8 activations and int8 weights. + """ + return FusedMoEQuantConfig.make( + torch.int8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=None, + ) + + +def mxfp4_w4a4_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and mxfp4 weights. + """ + return FusedMoEQuantConfig.make( + "mxfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=block_shape, + ) + + +def nvfp4_moe_quant_config( + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for nvfp4 activations and nvfp4 weights. 
+ """ + return FusedMoEQuantConfig.make( + "nvfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=None, + ) + + +def int4_w4a16_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit float activations and int4 weights. + Note: Activations are pre-quantized. + """ + group_shape = GroupShape(*block_shape) if block_shape is not None else None + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(shape=group_shape), + _a2=FusedMoEQuantDesc(shape=group_shape), + _w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp), + _w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp), + ) + + +def int8_w8a16_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit float activations and int8 weights. + Note: Activations are pre-quantized. + """ + group_shape = GroupShape(*block_shape) if block_shape is not None else None + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(shape=group_shape), + _a2=FusedMoEQuantDesc(shape=group_shape), + _w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp), + _w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp), + ) + + +def biased_moe_quant_config( + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], +) -> FusedMoEQuantConfig: + """ + Construct a quant config for unquantized activations with biases. 
+ """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(), + _a2=FusedMoEQuantDesc(), + _w1=FusedMoEQuantDesc(bias=w1_bias), + _w2=FusedMoEQuantDesc(bias=w2_bias), + ) + + +# A FusedMoEQuantConfig constant for an unquantized MoE op. +FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make() @dataclass @@ -315,8 +718,6 @@ class FusedMoEConfig: # The activation type. in_dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] = None - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE has_bias: bool = False @@ -328,34 +729,6 @@ class FusedMoEConfig: assert self.max_num_tokens > 0 - @property - def quant_dtype(self) -> Union[torch.dtype, str, None]: - if self.quant_config is not None: - return self.quant_config.quant_dtype - else: - return None - - @property - def block_shape(self) -> Optional[list[int]]: - if self.quant_config is not None: - return self.quant_config.block_shape - else: - return None - - @property - def per_act_token_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_act_token_quant - else: - return False - - @property - def per_out_ch_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_out_ch_quant - else: - return False - @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -401,97 +774,6 @@ class FusedMoEConfig: """ Whether to use FlashInfer cutlass kernels for NVFP4 MoE. 
""" - return (self.quant_config is not None - and self.quant_config.quant_dtype == "nvfp4" - and envs.VLLM_USE_FLASHINFER_MOE_FP4 + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and has_flashinfer_cutlass_fused_moe() and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") - - @staticmethod - def make( - num_experts: int, - experts_per_token: int, - hidden_dim: int, - num_local_experts: int, - moe_parallel_config: FusedMoEParallelConfig, - in_dtype: torch.dtype, - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config: Optional[Union[FusedMoEQuantConfig, - QuantizationConfig]] = None, - has_bias: bool = False, - ) -> "FusedMoEConfig": - - _quant_config: Optional[FusedMoEQuantConfig] = None - - if quant_config is not None and isinstance(quant_config, - QuantizationConfig): - if hasattr(quant_config, 'weight_block_size'): - block_shape = quant_config.weight_block_size - else: - block_shape = None - per_act_token_quant = False - per_out_ch_quant = False - quant_dtype: Union[torch.dtype, str, None] = None - - input_quant = get_quant_config_input_quant(quant_config) - weight_quant = get_quant_config_weight_quant(quant_config) - - if input_quant is not None: - per_act_token_quant = (input_quant.strategy - == QuantizationStrategy.TOKEN - if input_quant is not None else False) - - if input_quant.num_bits == 8: - if input_quant.type == QuantizationType.FLOAT: - quant_dtype = torch.float8_e4m3fn - elif input_quant.type == QuantizationType.INT: - quant_dtype = torch.int8 - - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - if quant_dtype is None and isinstance(quant_config, Fp8Config): - quant_dtype = torch.float8_e4m3fn - - from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Config) - if (quant_dtype is None and isinstance(quant_config, Mxfp4Config) - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): - quant_dtype = "mxfp8" - - from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptNvFp4Config) - if quant_dtype is None and 
isinstance(quant_config, - ModelOptNvFp4Config): - quant_dtype = "nvfp4" - - if weight_quant is not None: - per_out_ch_quant = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL) - - if quant_dtype is not None: - _quant_config = FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - ) - else: - _quant_config = FusedMoEQuantConfig() - if moe_parallel_config.dp_size > 1: - logger.warning_once("MoE DP setup unable to determine " - "quantization scheme or unsupported " - "quantization type. This model will " - "not run with DP enabled.") - else: - _quant_config = quant_config - - return FusedMoEConfig( - num_experts=num_experts, - experts_per_token=experts_per_token, - hidden_dim=hidden_dim, - num_local_experts=num_local_experts, - moe_parallel_config=moe_parallel_config, - in_dtype=in_dtype, - quant_config=_quant_config, - max_num_tokens=max_num_tokens, - has_bias=has_bias, - ) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 95d23ec0346c1..957ffca0d1246 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - )) + assert quant_config.use_fp8_w8a8 + super().__init__(quant_config) self.out_dtype = out_dtype self.ab_strides1 = ab_strides1 self.ab_strides2 = 
ab_strides2 @@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" - assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" expert_num_tokens = None if expert_tokens_meta is not None: @@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): in_dtype = hidden_states.dtype run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, - global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, - self.c_strides2, workspace13, workspace2, expert_num_tokens, + global_num_experts, expert_map, self.w1_scale, self.w2_scale, + a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2, + self.c_strides1, self.c_strides2, workspace13, workspace2, + expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, self.per_act_token_quant, self.per_out_ch_quant, use_batched_format, topk_weights) @@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): def __init__( self, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): super().__init__( 
out_dtype, - per_act_token_quant, - per_out_ch_quant, ab_strides1, ab_strides2, c_strides1, c_strides2, - block_shape, + quant_config, ) @property @@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): max_experts_per_worker: int, num_dispatchers: int, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): super().__init__( out_dtype, - per_act_token_quant, - per_out_ch_quant, ab_strides1, ab_strides2, c_strides1, c_strides2, - block_shape, + quant_config, ) assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker @@ -414,16 +395,12 @@ def cutlass_moe_fp8( w2_q: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - per_act_token: Optional[bool] = None, + quant_config: FusedMoEQuantConfig, activation: str = "silu", - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, @@ -475,10 +452,18 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. 
""" - if per_act_token is None: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.size(0) + assert quant_config is not None + + if quant_config.a1_scale is not None: + assert (quant_config.per_act_token_quant == + quant_config.a1_scale.numel() != 1) + if quant_config.a2_scale is not None: + assert (quant_config.per_act_token_quant == + quant_config.a2_scale.numel() != 1) + + assert (quant_config.w1_scale is None + or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) + == w1_q.size(1)))) num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( 0) @@ -487,12 +472,11 @@ def cutlass_moe_fp8( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( out_dtype=a.dtype, - per_act_token_quant=per_act_token, - per_out_ch_quant=per_out_ch, ab_strides1=ab_strides1, ab_strides2=ab_strides2, c_strides1=c_strides1, c_strides2=c_strides2, + quant_config=quant_config, ), ) @@ -502,14 +486,9 @@ def cutlass_moe_fp8( w2_q, topk_weights, topk_ids, - False, - activation, - num_experts, - expert_map, - w1_scale, - w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + activation=activation, + global_num_experts=num_experts, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -542,7 +521,7 @@ def run_cutlass_moe_fp4( ) -> None: """ MoE implementation for FP4 Inputs - + # Gemm 1 a: Input tensor: [m, k] (half/bfloat16) a1_gscale: Activation scale per expert: [e] (float32) @@ -552,16 +531,16 @@ def run_cutlass_moe_fp4( full precision) w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) (Block size = 16 for NVFP4) - + # Gemm 2 a2_gscale: Activation scale per expert: [e] w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 - + topk_weights: [m, topk] dtype: float8 topk_ids: [m, 
topk] dtype: float8 - + m, n, k: Unquantized weight shapes, dtype: int e: number of experts, dtype: int @@ -652,42 +631,21 @@ def run_cutlass_moe_fp4( return +# Split into batched and non-batched class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, max_experts_per_worker: int, out_dtype: torch.dtype, - per_act_token_quant: bool, - per_out_ch_quant: bool, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, use_batched_format: bool = False, ): - super().__init__( - # NVFP4 requires two levels of quantization, which involves - # computing some scaling factors dynamically. This makes it - # incompatible with the typical prepare -> MoE -> finalize - # pipeline. Move the quantization logic into the MoE body. - FusedMoEQuantConfig( - quant_dtype=None, # skip quantization in prepare/finalize - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - )) + super().__init__(quant_config) self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype self.use_batched_format = use_batched_format - # TODO(bnell): put this stuff into quant config? 
- self.g1_alphas = g1_alphas - self.g2_alphas = g2_alphas - self.a1_gscale = a1_gscale - self.a2_gscale = a2_gscale - @property def activation_formats( self @@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, + a1q_scale: Optional[torch.Tensor], # unused workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): a=hidden_states, a1_gscale=self.a1_gscale, w1_fp4=w1, - w1_blockscale=w1_scale, + w1_blockscale=self.w1_scale, w1_alphas=self.g1_alphas, a2_gscale=self.a2_gscale, w2_fp4=w2, - w2_blockscale=w2_scale, + w2_blockscale=self.w2_scale, w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, @@ -788,14 +741,9 @@ def cutlass_moe_fp4( a: torch.Tensor, w1_fp4: torch.Tensor, w2_fp4: torch.Tensor, - w1_blockscale: torch.Tensor, - w2_blockscale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + quant_config: FusedMoEQuantConfig, m: int, n: int, k: int, @@ -805,17 +753,31 @@ def cutlass_moe_fp4( assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") + + # TODO(bnell): this feels a bit hacky + # NVFP4 requires two levels of quantization, which involves + # computing some scaling factors dynamically. This makes it + # incompatible with the typical prepare -> MoE -> finalize + # pipeline. Move the quantization logic into the MoE body. 
+ quant_config = FusedMoEQuantConfig.make( + quant_dtype=None, # skip quantization in prepare/finalize + per_act_token_quant=quant_config.per_act_token_quant, + per_out_ch_quant=quant_config.per_out_ch_quant, + block_shape=quant_config.block_shape, + g1_alphas=quant_config.g1_alphas, + g2_alphas=quant_config.g2_alphas, + a1_gscale=quant_config.a1_gscale, + a2_gscale=quant_config.a2_gscale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + ) + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( - g1_alphas, - g2_alphas, - a1_gscale, - a2_gscale, max_experts_per_worker=e, out_dtype=a.dtype, - per_act_token_quant=False, - per_out_ch_quant=False, + quant_config=quant_config, use_batched_format=False, ), ) @@ -830,10 +792,6 @@ def cutlass_moe_fp4( activation="silu", global_num_experts=e, expert_map=None, - w1_scale=w1_blockscale, - w2_scale=w2_blockscale, - a1_scale=None, - a2_scale=None, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm( return True +# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8. 
def run_cutlass_block_scaled_fused_experts( a: torch.Tensor, w1: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index c0bfda73eee0d..8830b95df7cf0 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import Optional import torch @@ -9,9 +8,11 @@ from tqdm import tqdm import vllm.envs as env import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce) + compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute, + deepgemm_unpermute_and_reduce) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous logger = init_logger(__name__) -@functools.cache -def deep_gemm_block_shape() -> list[int]: - # Lazy import to avoid CUDA initialization problems. 
- import deep_gemm as dg - block = dg.get_m_alignment_for_contiguous_layout() - return [block, block] - - def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: align = deep_gemm_block_shape()[0] return align <= M and N % align == 0 and K % align == 0 @@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=deep_gemm_block_shape(), - )) + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + assert quant_config.block_shape == deep_gemm_block_shape() + assert quant_config.quant_dtype == torch.float8_e4m3fn + assert not quant_config.per_act_token_quant + assert not quant_config.per_out_ch_quant @property def activation_formats( @@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - assert self.block_shape is not None assert a1q_scale is not None - assert w1_scale is not None - assert w2_scale is not None + assert self.a2_scale is None + assert self.block_shape is not None + assert self.w1_scale is not None + assert self.w2_scale is not None a1q = hidden_states _, N, K = w1.size() @@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): aq_out=a1q_perm) assert a1q.size(0) == M_sum - m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), + m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, 
self.w1_scale), mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) @@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): column_major_scales=True, out_q=quant_out) - m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), + m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids) if apply_router_weight_on_input: @@ -348,9 +336,16 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=deep_gemm_block_shape()) + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - DeepGemmExperts(), + DeepGemmExperts(quant_config), ) return fn( hidden_states, @@ -358,13 +353,9 @@ def deep_gemm_moe_fp8( w2, topk_weights, topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 92cbb1742974c..5d6b9c87a6b76 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # Quant and Dispatch a1q, a1q_scale = moe_kernel_quantize_input( 
a1, - a1_scale, + quant_config.a1_scale, quant_dtype=quant_config.quant_dtype, per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape, @@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): else: a1q = a1 a1q_scale = None - a1_post_scale = a1_scale + a1_post_scale = quant_config.a1_scale return (lambda *args: None, self._do_dispatch(tokens=a1q, @@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - (_, receiver) = self.prepare_async(a1, a1_scale, a2_scale, - topk_weights, topk_ids, num_experts, - expert_map, + (_, receiver) = self.prepare_async(a1, topk_weights, topk_ids, + num_experts, expert_map, apply_router_weight_on_input, quant_config) return receiver() diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 61f8297f0f148..01df7770463d0 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def _do_quant( self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - a1_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], - per_act_token_quant: bool, - block_shape: Optional[list[int]], + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - block_k = block_shape[1] if block_shape is not None else None if self.use_fp8_dispatch: + block_k = 
quant_config.block_shape[ + 1] if quant_config.block_shape is not None else None if block_k == DEEPEP_QUANT_BLOCK_SIZE: # DeepEP kernels did the quantization for us. x, x_scales = x @@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # TODO (varun): Optimization - Use a batched version of quant x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, - per_act_token_quant, - block_shape) + x, x_scales = moe_kernel_quantize_input( + x, quant_config.a1_scale, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) x = x.view((num_experts, -1, hidden_dim)) - if quant_dtype is not None: + if quant_config.quant_dtype is not None: assert x_scales is not None x_scales = normalize_batched_scales_shape(x_scales, num_experts) @@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): assert hidden_size % 128 == 0, \ "DeepEP kernels quantize the inputs in blocks of shape 128" - has_per_token_scales = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + has_per_token_scales = quant_config.a1_scale.numel( + ) != 1 if quant_config.a1_scale is not None else ( + quant_config.a2_scale.numel() != 1 + if quant_config.a2_scale is not None else False) assert not has_per_token_scales, ( "low_latency kernels doesn't support dispatching per-token scales") @@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return_recv_hook=True) self.handles[a2a_idx] = handle - return (hook, lambda: self._receiver(expert_x, expert_num_tokens, - a1_scale, a1.dtype, quant_config)) + return ( + hook, + lambda: 
self._receiver(expert_x, expert_num_tokens, quant_config. + a1_scale, a1.dtype, quant_config)) def _receiver( self, expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], expert_num_tokens: torch.Tensor, - a1_scale, - a1_dtype, + a1_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) + expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, + quant_config) expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) @@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - hook, receiver = self.prepare_async(a1, a1_scale, a2_scale, - topk_weights, topk_ids, + hook, receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts, expert_map, apply_router_weight_on_input, quant_config) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index feab3f74cac53..6eeec18a6ec87 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Optional import torch @@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def 
__init__( self, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, out_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], + quant_config: FusedMoEQuantConfig, ep_rank: int = 0, ep_size: int = 1, tp_rank: int = 0, tp_size: int = 1, ): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=False, - block_shape=None, - )) - assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( + super().__init__(quant_config) + assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( "Only nvfp4,fp8 quantization are currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank self.tp_size = tp_size - self.g1_alphas = g1_alphas - self.g2_alphas = g2_alphas - self.a1_gscale = a1_gscale - self.a2_gscale = a2_gscale self.out_dtype = out_dtype @property @@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], # Not used workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): fc2_expert_weights = w2 else: # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( + assert self.w1_scale is not None and self.w2_scale is not None, ( "w1_scale and w2_scale must not " "be None for FlashInferExperts") # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. 
quant_scales = [ self.a1_gscale, - w1_scale.view(torch.int32), + self.w1_scale.view(torch.int32), self.g1_alphas, self.a2_gscale, - w2_scale.view(torch.int32), + self.w2_scale.view(torch.int32), self.g2_alphas, ] # FlashInfer API requires weight to be long for nvfp4 @@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + quant_config: FusedMoEQuantConfig, inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, @@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4( ) -> torch.Tensor: fused_experts = mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, - a1_gscale=a1_gscale), + FlashInferCutlassMoEPrepareAndFinalize(use_dp=False), FlashInferExperts( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, out_dtype=hidden_states.dtype, - quant_dtype="nvfp4", + quant_config=quant_config, )) return fused_experts( @@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4( activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 157cb36d4ffd3..8c7eff59f3cd1 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, use_dp: bool, - a1_gscale: Optional[torch.Tensor], num_dispatchers: int = 1, ): super().__init__() self.num_dispatchers_ = 
num_dispatchers self.use_dp = use_dp - self.a1_gscale = a1_gscale self.local_tokens = None @property @@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], # Not used - a2_scale: Optional[torch.Tensor], # Not used topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, - # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1q, a1q_scale = moe_kernel_quantize_input( a1, - self.a1_gscale, + quant_config.a1_gscale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py new file mode 100644 index 0000000000000..e358143fac7c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import List # noqa: UP035 +from typing import Optional + +import torch + +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + calculate_tile_tokens_dim) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import direct_register_custom_op + + +def flashinfer_fused_moe_blockscale_fp8( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: 
int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: List[int], #noqa: UP006 + routed_scaling: float = 1.0) -> torch.Tensor: + from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + assert top_k <= global_num_experts + assert top_k <= 8 + assert topk_group <= 4 + assert global_num_experts > num_expert_group + assert global_num_experts % num_expert_group == 0 + assert global_num_experts % 4 == 0 + assert top_k < (topk_group * global_num_experts / num_expert_group) + assert block_shape == [128, 128] + + a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() + return flashinfer_trtllm_fp8_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale_inv, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale_inv, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling, + tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, + global_num_experts), + routing_method_type=2, # DeepSeek-styled routing method + use_shuffled_weight=False, + ) + + +def flashinfer_fused_moe_blockscale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0) -> torch.Tensor: + return 
torch.empty_like(x) + + +# TODO(bnell): Does this really need to be a torch.op? +direct_register_custom_op( + op_name="flashinfer_fused_moe_blockscale_fp8", + op_func=flashinfer_fused_moe_blockscale_fp8, + mutates_args=[], + fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + +def flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: + num_expert_group = num_expert_group if num_expert_group is not None else 0 + topk_group = topk_group if topk_group is not None else 0 + + quant_hidden_states, _ = moe_kernel_quantize_input( + hidden_states, + input_scale, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False) + + from vllm.utils.flashinfer import ( + flashinfer_trtllm_fp8_per_tensor_scale_moe) + return flashinfer_trtllm_fp8_per_tensor_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=quant_hidden_states, + gemm1_weights=gemm1_weights, + output1_scales_scalar=output1_scales_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + gemm2_weights=gemm2_weights, + output2_scales_scalar=output2_scales_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + 
use_routing_scales_on_input=use_routing_scales_on_input, + tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], + top_k, num_experts), + routing_method_type=routing_method_type) + + +def flashinfer_fused_moe_per_tensor_scale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +# TODO(bnell): Does this really need to be a torch.op? +direct_register_custom_op( + op_name="flashinfer_fused_moe_per_tensor_scale_fp8", + op_func=flashinfer_fused_moe_per_tensor_scale_fp8, + mutates_args=["hidden_states"], + fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 88063668e9188..fe6ac458a9593 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -8,7 +8,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, try_get_optimal_moe_config) + try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) from 
vllm.model_executor.layers.fused_moe.utils import ( @@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dtype=torch.float32, device=a1.device) else: - assert a1_scale is None + assert quant_config.a1_scale is None b_a1_scale = None first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts - a1_scale = normalize_scales_shape(a1_scale) - a2_scale = normalize_scales_shape(a2_scale) + a1_scale = normalize_scales_shape(quant_config.a1_scale) for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() @@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): self, max_num_tokens: int, num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - assert not use_mxfp4_w4a4, "NYI" + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert not self.quant_config.use_mxfp4_w4a4, "NYI" self.max_num_tokens = max_num_tokens 
self.num_dispatchers = num_dispatchers @@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): tmp = _resize_cache(workspace2, (num, N)) if self.quant_config.is_quantized: - assert a1q_scale is not None and w1_scale is not None + assert a1q_scale is not None and self.w1_scale is not None input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) - w1_dq = self.dequant(w1[expert], w1_scale[expert]) + w1_dq = self.dequant(w1[expert], self.w1_scale[expert]) input = input[:num] @ w1_dq.transpose(0, 1) else: input = hidden_states[expert, :num, :] @ w1[expert].transpose( @@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): self.activation(activation, tmp, input.to(tmp.dtype)) if self.quant_config.is_quantized: - assert w2_scale is not None - w2_dq = self.dequant(w2[expert], w2_scale[expert]) + assert self.w2_scale is not None + w2_dq = self.dequant(w2[expert], self.w2_scale[expert]) else: w2_dq = w2[expert] @@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self, max_num_tokens: int, num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - 
use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - assert not use_mxfp4_w4a4, "NYI" + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert not self.quant_config.use_mxfp4_w4a4, "NYI" assert max_num_tokens > 0 assert num_dispatchers > 0 - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_mxfp4_w4a4 = use_mxfp4_w4a4 self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): # Check constraints. 
- if self.use_int4_w4a16: + if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: @@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w1.size(0) == E assert w2.size(0) == E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - use_mxfp4_w4a4=self.use_mxfp4_w4a4, - dtype=hidden_states.dtype) + config_dtype = self.quant_config.config_name(hidden_states.dtype) config = try_get_optimal_moe_config( w1.size(), @@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - if self.use_fp8_w8a8: + # TODO(bnell): should this be done for any quantized type? + if self.quant_config.use_fp8_w8a8: intermediate_cache1.fill_(0) a1q_scale = normalize_batched_scales_shape(a1q_scale, E) @@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a1q_scale, - B_scale=w1_scale, - B_zp=w1_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + B_scale=self.w1_scale, + B_zp=self.w1_zp, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) @@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache1.view(-1, N)) qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, max_num_tokens, E, N, + intermediate_cache2, self.a2_scale, max_num_tokens, E, N, expert_num_tokens, self.quant_dtype, self.per_act_token_quant, self.block_shape) @@ -1032,11 +981,11 @@ class 
BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a2q_scale, - B_scale=w2_scale, - B_zp=w2_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + B_scale=self.w2_scale, + B_zp=self.w2_zp, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 36c2ab8b2d5f3..d4de3f640865e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Fused MoE kernel.""" +"""Fused MoE Triton kernels.""" import functools import json import os # torch.compile needs typing.List. 
It will fail torch.library.infer_schema # otherwise from typing import List # noqa: UP035 -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -18,7 +18,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger # yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, get_config_quant_dtype) + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( _valid_cutlass_block_scaled_grouped_gemm, run_cutlass_block_scaled_fused_experts) @@ -32,11 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - calculate_tile_tokens_dim) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + _resize_cache, activation_without_mul, moe_kernel_quantize_input) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( dequant_mxfp4) from vllm.platforms import current_platform @@ -1049,87 +1045,66 @@ def fused_grouped_topk( return topk_values.to(torch.float32), topk_indices.to(torch.int32) -def get_config_dtype_str( - dtype: torch.dtype, - use_int4_w4a16: Optional[bool] = False, - use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False, - use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: - if use_fp8_w8a8: - return "fp8_w8a8" - elif use_int8_w8a16: - return "int8_w8a16" - elif use_int4_w4a16: - return "int4_w4a16" - elif use_mxfp4_w4a4: - return "mxfp4_w4a4" - elif dtype == torch.float: - # avoiding cases where kernel fails when float32 MoE - # use fp16/bfloat16 configs - return 
"float32" - return None - - def inplace_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> None: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, #noqa: UP006 + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, is_act_and_mul, - apply_router_weight_on_input, use_fp8_w8a8, + activation, 
apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) -def inplace_fused_experts_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> None: +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, #noqa: UP006 + w1_bias: 
Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> None: pass @@ -1143,175 +1118,6 @@ direct_register_custom_op( ) -def flashinfer_fused_moe_blockscale_fp8( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: List[int], #noqa: UP006 - routed_scaling: float = 1.0) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe - assert top_k <= global_num_experts - assert top_k <= 8 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 - assert global_num_experts % 4 == 0 - assert top_k < (topk_group * global_num_experts / num_expert_group) - assert block_shape == [128, 128] - - a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) - # NOTE: scales of hidden states have to be transposed! 
- a_sf_t = a_sf.t().contiguous() - return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=w13_weight, - gemm1_weights_scale=w13_weight_scale_inv, - gemm2_weights=w2_weight, - gemm2_weights_scale=w2_weight_scale_inv, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, - global_num_experts), - routing_method_type=2, # DeepSeek-styled routing method - use_shuffled_weight=False, - ) - - -def flashinfer_fused_moe_blockscale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routed_scaling: float = 1.0) -> torch.Tensor: - return torch.empty_like(x) - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_blockscale_fp8", - op_func=flashinfer_fused_moe_blockscale_fp8, - mutates_args=[], - fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - -def flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: 
Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - - quant_hidden_states, _ = moe_kernel_quantize_input( - hidden_states, - input_scale, - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False) - - from vllm.utils.flashinfer import ( - flashinfer_trtllm_fp8_per_tensor_scale_moe) - return flashinfer_trtllm_fp8_per_tensor_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=quant_hidden_states, - gemm1_weights=gemm1_weights, - output1_scales_scalar=output1_scales_scalar, - output1_scales_gate_scalar=output1_scales_gate_scalar, - gemm2_weights=gemm2_weights, - output2_scales_scalar=output2_scales_scalar, - num_experts=num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling_factor, - use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], - top_k, num_experts), - routing_method_type=routing_method_type) - - -def flashinfer_fused_moe_per_tensor_scale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - 
routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: - pass - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_per_tensor_scale_fp8", - op_func=flashinfer_fused_moe_per_tensor_scale_fp8, - mutates_args=["hidden_states"], - fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1319,7 +1125,6 @@ def outplace_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", - is_act_and_mul: bool = True, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -1341,37 +1146,37 @@ def outplace_fused_experts( ) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, - is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, - per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) + apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, + use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, + global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape, w1_bias, w2_bias) def outplace_fused_experts_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] 
= None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1403,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. 
-def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, +) -> torch.Tensor: + + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + use_fp8_w8a8 = quant_config.use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. # However, on B200, we use DeepGemm for all cases because they only support # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
- if (allow_deep_gemm and use_fp8_w8a8 and + if (allow_deep_gemm and quant_config.use_fp8_w8a8 and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): + assert quant_config is not None assert apply_router_weight_on_input is False - assert is_act_and_mul, ( - "DeepGemm only supports is_act_and_mul=True for now.") return deep_gemm_moe_fp8( hidden_states=hidden_states, w1=w1, @@ -1452,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and _valid_cutlass_block_scaled_grouped_gemm( w1, w2, inplace, activation, apply_router_weight_on_input, expert_map)): + assert quant_config is not None return run_cutlass_block_scaled_fused_experts( a=hidden_states, w1=w1, w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, topk_weights=topk_weights, topk_ids=topk_ids) else: @@ -1478,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - is_act_and_mul=is_act_and_mul, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_channel_quant=per_channel_quant, + use_fp8_w8a8=quant_config.use_fp8_w8a8, + use_int8_w8a8=quant_config.use_int8_w8a8, + use_int8_w8a16=quant_config.use_int8_w8a16, + use_int4_w4a16=quant_config.use_int4_w4a16, + use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4, + per_channel_quant=quant_config.per_act_token_quant, 
global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - w1_bias=w1_bias, - w2_bias=w2_bias, - ) + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + w1_zp=quant_config.w1_zp, + w2_zp=quant_config.w2_zp, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, + block_shape=quant_config.block_shape, + w1_bias=quant_config.w1_bias, + w2_bias=quant_config.w2_bias) + + +SILU_NO_MUL: str = activation_without_mul("silu") +GELU_NO_MUL: str = activation_without_mul("gelu") + + +def _get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_mxfp4_w4a4: bool, +) -> Union[None, torch.dtype, str]: + """ + Get the quantization type based on the quantization strategy flags. + We don't have a quant_config at this point so we need to work backwards. + A return type of None means no quantization is required because the + input is unquantized or has been quantized prior to calling + fused_experts_impl. 
+ """ + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + elif use_mxfp4_w4a4: + return "mxfp4" + return None def fused_experts_impl( @@ -1508,7 +1328,6 @@ def fused_experts_impl( topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", - is_act_and_mul: bool = True, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -1557,17 +1376,18 @@ def fused_experts_impl( # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) - config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - dtype=hidden_states.dtype) - qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4) + config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, + dtype=hidden_states.dtype) + + # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are + # quantized prior to calling fused_experts. 
+ quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_mxfp4_w4a4=use_mxfp4_w4a4) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1640,7 +1460,7 @@ def fused_experts_impl( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, - quant_dtype=qtype, + quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape) @@ -1671,30 +1491,29 @@ def fused_experts_impl( B_bias=w1_bias) # Activation function with multiplication - if activation == "silu" and is_act_and_mul: + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - elif activation == "gelu" and is_act_and_mul: + elif activation == "gelu": torch.ops._C.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - elif activation == "swigluoai" and is_act_and_mul: + elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 torch.ops._C.swigluoai_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) # Activation function without multiplication - elif activation == "silu": + elif activation == SILU_NO_MUL: intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) - elif activation == "gelu": + elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}, " - f"with is_act_and_mul={is_act_and_mul}.") + raise ValueError(f"Unsupported FusedMoe activation: {activation}.") qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - quant_dtype=qtype, + quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape) @@ -1726,164 +1545,13 @@ def fused_experts_impl( return out_hidden_states -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - 
inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - is_act_and_mul (bool): If True, use activation-and-mul function for - activation (self-gated activation), otherwise use activation function - for activation (ungated activation). 
- - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseekv2 model uses grouped_topk - - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and - OCP MXFP4 activation to compute the inner products for w1 and w2. - Defaults to False. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. - - block_shape: (Optional[list[int]]): Optional block size for block-wise - quantization. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. 
- """ - if not is_act_and_mul: - assert inplace is False, ( - "is_act_and_mul=False is not supported with inplace=True") - - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, - topk, renormalize, - num_expert_group, topk_group) - elif custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) - - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - is_act_and_mul=is_act_and_mul, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - w1_bias=w1_bias, - w2_bias=w2_bias) - - class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int8_w8a16 = use_int8_w8a16 - 
self.use_mxfp4_w4a4 = use_mxfp4_w4a4 + super().__init__(quant_config) @property def activation_formats( @@ -1929,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): # Check constraints. - if self.use_int4_w4a16: + if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: @@ -1964,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): if global_num_experts == -1: global_num_experts = E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - use_mxfp4_w4a4=self.use_mxfp4_w4a4, - dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, - config_dtype, + self.quant_config.config_name(hidden_states.dtype), num_tokens, block_shape=self.block_shape, ) @@ -2008,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): w1, intermediate_cache1, a1q_scale, - w1_scale, - w1_zp, + self.w1_scale, + self.w1_zp, None, # topk_weights sorted_token_ids, expert_ids, @@ -2018,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): top_k_num, config, compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + 
use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, - B_bias=None # TODO support B_bias + B_bias=self.w1_bias, ) self.activation(activation, intermediate_cache2, @@ -2033,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.quant_dtype, + intermediate_cache2, self.a2_scale, self.quant_dtype, self.per_act_token_quant, self.block_shape) invoke_fused_moe_kernel( @@ -2041,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): w2, intermediate_cache3, a2q_scale, - w2_scale, - w2_zp, + self.w2_scale, + self.w2_zp, topk_weights, sorted_token_ids, expert_ids, @@ -2051,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): 1, config, compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, - B_bias=None # TODO support B_bias + B_bias=self.w2_bias, ) ops.moe_sum(intermediate_cache3, output) def modular_triton_fused_moe( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, - per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, -) -> mk.FusedMoEModularKernel: + quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - 
per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ), + TritonExperts(quant_config), ) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 312befe2c1d71..614a83ad1158c 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.utils import has_triton_kernels @@ -23,9 +25,6 @@ if has_triton_kernels(): "Failed to import Triton kernels. 
Please make sure your triton " "version is compatible.") -if TYPE_CHECKING: - from triton_kernels.matmul_ogs import PrecisionConfig - def triton_kernel_moe_forward( hidden_states: torch.Tensor, @@ -35,20 +34,10 @@ def triton_kernel_moe_forward( topk: int, renormalize: bool, activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, ) -> torch.Tensor: routing_data, gather_idx, scatter_idx = routing(gating_output, @@ -64,20 +53,10 @@ def triton_kernel_moe_forward( gather_idx, scatter_idx, activation=activation, + quant_config=quant_config, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, - w1_precision=w1_precision, - w2_precision=w2_precision, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + expert_map=expert_map) # This is a triton implementation of the fused_experts function @@ -90,28 +69,23 @@ def triton_kernel_fused_experts( gather_indx, # GatherIndx scatter_indx, # ScatterIndx activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, swiglu_alpha: float = 1.702, swiglu_limit: float = 7.0, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, 
global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, + a1q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 - assert w1_bias is None or w1_bias.dtype == torch.float32 - assert w2_bias is None or w2_bias.dtype == torch.float32 + assert (quant_config.w1_bias is None + or quant_config.w1_bias.dtype == torch.float32) + assert (quant_config.w2_bias is None + or quant_config.w2_bias.dtype == torch.float32) # Shape check, only check non-mxfp4 assert hidden_states.shape[-1] == w1.shape[-2] @@ -130,20 +104,20 @@ def triton_kernel_fused_experts( intermediate_cache1 = matmul_ogs( hidden_states, w1, - w1_bias, + quant_config.w1_bias, routing_data, gather_indx=gather_indx, - precision_config=w1_precision, + precision_config=quant_config.w1_precision, gammas=gammas if apply_router_weight_on_input else None, fused_activation=act) intermediate_cache3 = matmul_ogs( intermediate_cache1, w2, - w2_bias, + quant_config.w2_bias, routing_data, scatter_indx=scatter_indx, - precision_config=w2_precision, + precision_config=quant_config.w2_precision, gammas=None if apply_router_weight_on_input else gammas, y=output_tensor, ) @@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - quant_config, max_num_tokens: int, num_dispatchers: int, - w1_precision: "PrecisionConfig", - w2_precision: "PrecisionConfig", - w1_bias: Optional[torch.Tensor], - w2_bias: 
Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, ): super().__init__(quant_config) self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - self.w1_precision = w1_precision - self.w2_precision = w2_precision - self.w1_bias = w1_bias - self.w2_bias = w2_bias @property def activation_formats( @@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): hidden_states, w1, w2, - None, - None, - None, + routing_data=None, + gather_indx=None, + scatter_indx=None, activation=activation, + quant_config=self.quant_config, apply_router_weight_on_input=False, - use_fp8_w8a8=False, - per_channel_quant=False, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=self.w1_bias, - w2_bias=self.w2_bias, - w1_precision=self.w1_precision, - w2_precision=self.w2_precision, - a1_scale=a1q_scale, - a2_scale=a2_scale) + a1q_scale=a1q_scale) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d22bb253f4a72..ae3b67a2b84e6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -22,7 +22,8 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp # yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig) + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig, + 
FusedMoEQuantConfig, biased_moe_quant_config) # yapf: enable from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEModularKernel, @@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - # TODO(bnell): also pass quant_config? def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe - self.fused_experts: Optional[Callable] = None + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None + self.fused_experts: Optional[FusedMoEModularKernel] = None self.topk_indices_dtype = None @abstractmethod @@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase): @staticmethod def _maybe_make_prepare_finalize( - moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: + moe: FusedMoEConfig, + quant_config: Optional[FusedMoEQuantConfig], + ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + # TODO: could allow this now assert not moe.use_flashinfer_cutlass_kernels, \ "Must be created in modelopt.py" if moe.use_pplx_kernels: + assert quant_config is not None + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, moe.hidden_dim, moe.in_dtype, - moe.quant_dtype, - per_act_token_quant=moe.per_act_token_quant, - block_shape=moe.block_shape, + quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, ) all_to_all_args = dict( @@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ) elif moe.use_deepep_ll_kernels: + assert quant_config is not None all_to_all_args = dict( max_num_tokens_per_dp_rank=moe.max_num_tokens, token_hidden_size=moe.hidden_dim, @@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase): all2all_manager.world_size) handle = 
all2all_manager.get_handle(all_to_all_args) - # Note : We may want to use FP8 dispatch even otherwise just to - # reduce datamovement - use_fp8_dispatch = (moe.quant_config is not None - and moe.quant_config.quant_dtype - == current_platform.fp8_dtype() - and moe.quant_config.block_shape - == DEEPEP_QUANT_BLOCK_SHAPE) + # Note: We may want to use FP8 dispatch just to reduce + # data movement. + use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE) prepare_finalize = DeepEPLLPrepareAndFinalize( handle, @@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): return prepare_finalize def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[FusedMoEPrepareAndFinalize]: - if moe.moe_parallel_config.use_all2all_kernels: - return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize( + self.moe, self.moe_quant_config) else: return None @@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): # prepare_communication_buffer_for_model. def init_prepare_finalize(self, layer: torch.nn.Module): assert self.moe is not None - prepare_finalize = self.maybe_make_prepare_finalize(self.moe) + + # We must get the quant config here so that the layer is + # completely initialized, i.e. all weights loaded and post + # processed. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + + prepare_finalize = self.maybe_make_prepare_finalize() if prepare_finalize is not None: logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, @@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): assert self.fused_experts is None, \ f"Attempt to override experts for {id(self)}!" 
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, self.moe, layer) + experts = self.select_gemm_impl(prepare_finalize, layer) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, @@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate @@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase): f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") + @abstractmethod + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + raise NotImplementedError + @abstractmethod def apply( self, @@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: self.rocm_aiter_fused_experts = None # type: ignore + def maybe_make_prepare_finalize( + self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - # TODO(bnell): Remove. Every layer should have an moe config object. 
- moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, ) else: logger.debug("TritonExperts %s", self.moe) - return TritonExperts() + return TritonExperts(self.moe_quant_config) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - if self.has_bias: + if self.moe.has_bias: w13_bias = torch.nn.Parameter(torch.zeros( num_experts, 2 * intermediate_size_per_partition, @@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - if self.has_bias: + if self.moe.has_bias: w2_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, dtype=params_dtype), @@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): logical_replica_count=logical_replica_count, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + def forward_cuda( self, layer: torch.nn.Module, @@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: + assert self.fused_experts is None return self.rocm_aiter_fused_experts( 
hidden_states=x, w1=layer.w13_weight, @@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) elif self.fused_experts is not None: - if self.has_bias: + if self.moe.has_bias: raise ValueError( "FusedMoEModularKernel does not support bias.") return self.fused_experts( @@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_bias=layer.w13_bias if self.has_bias else None, - w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, + quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, @@ -933,16 +963,18 @@ class FusedMoE(CustomOp): # since model_config is not set in the pytest test. model_dtype = params_dtype - moe = FusedMoEConfig.make(num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config=quant_config, - has_bias=has_bias) + moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + has_bias=has_bias, + ) self.moe_config = moe + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts @@ -990,6 +1022,9 @@ class FusedMoE(CustomOp): # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: 
Optional[torch.Tensor] = None + + # TODO(bnell): flashinfer uses non-batched format. + # Does it really need a batched buffer? if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_config.use_flashinfer_cutlass_kernels): @@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_config.use_flashinfer_cutlass_kernels + return (self.moe_quant_config is not None + and self.moe_quant_config.quant_dtype == "nvfp4" + and self.moe_config.use_flashinfer_cutlass_kernels) def update_expert_map(self): # ep_size and ep_rank should already be updated @@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp): self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] + def ensure_moe_quant_config(self): + if self.quant_method.moe_quant_config is None: + self.quant_method.moe_quant_config = ( + self.quant_method.get_fused_moe_quant_config(self)) + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp): assert ( self.batched_router_logits.size(-1) == full_router_logits.size(-1)) + self.ensure_moe_quant_config() + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) if self.shared_experts is not None: full_shared_final_hidden_states = torch.empty_like( @@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp): router_logits: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None + + self.ensure_moe_quant_config() + # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. 
- use_flashinfer_cutlass_kernels = ( - self.dp_size > 1 - and self.moe_config.use_flashinfer_cutlass_kernels) + _use_flashinfer_cutlass_kernels = (self.dp_size > 1 and + self.use_flashinfer_cutlass_kernels) + if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or use_flashinfer_cutlass_kernels): + or _use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index efaa9cc058e41..58cd0294c8c44 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC): """ Perform any quantization (and/or) dispatching needed for this kernel. - a1: The (unquantized) input to the MoE layer. - - a1_scale: Optional scales for a1 - - a2_scale: Optional scales for the second MoE gemm. Required to make - sure the quantization is consistent for both gemms. - topk_ids: The topk ids. - topk_weights: The topk weights. - num_experts: The total number of experts in the global expert space. @@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC): space to the local expert space of the expert parallel shard. - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. + - quant_config: Quantization info provided by the fused experts. Returns a tuple of: - quantized + dispatched a. - - quantized + dispatched a1_scales. + - Optional quantized + dispatched a1_scales. 
- Optional ExpertTokensMetadata containing gpu/cpu tensors as big as the number of local experts with the information about the number of tokens assigned to each local expert. @@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC): def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC): raise NotImplementedError +# TODO: add supported activations method (return string) class FusedMoEPermuteExpertsUnpermute(ABC): """ An abstract base class for the [Permute-Experts-Unpermute] step described @@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: FusedMoEQuantConfig, ): - if quant_config is not None: - self.quant_config = quant_config - else: - self.quant_config = FusedMoEQuantConfig() + """ + quant_config: Quantization parameters for this experts instance. + """ + self.quant_config = quant_config @property @abstractmethod @@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError + # + # Various helpers for accessing quantization parameters from the + # quant_config. 
+ # + @property def quant_dtype(self) -> Optional[torch.dtype]: return self.quant_config.quant_dtype @@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def per_out_ch_quant(self) -> bool: return self.quant_config.per_out_ch_quant + @property + def a1_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.a1_scale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.a2_scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self.quant_config.a1_gscale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self.quant_config.a2_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_scale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_zp + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_bias + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_bias + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self.quant_config.g1_alphas + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self.quant_config.g2_alphas + # TODO (bnell): make this return a CHUNK_SIZE or None instead? 
@abstractmethod def supports_chunking(self) -> bool: @@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], @@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights: A map of row to expert weights. Some implementations - choose to do weight application. + choose to do weight application. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first MoE layer. @@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. - - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be - used for a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + used for a1. Result of quantization from prepare/finalize and not + from the FusedMoEQuantConfig. - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs must be large enough to hold output of either MoE gemm. 
- workspace2 (torch.Tensor): A scratch tensor used for the activation @@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, ) -> torch.Tensor: @@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, @@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, ) -> torch.Tensor: @@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module): Optional[torch.Tensor], torch.Tensor, torch.Tensor]: s = chunk_idx * CHUNK_SIZE e = min(s + CHUNK_SIZE, M) - return (a1q[s:e], _chunk_scales(a1q_scale, s, e), - _chunk_scales(a2_scale, s, - e), 
topk_ids[s:e], topk_weights[s:e]) + return ( + a1q[s:e], + _chunk_scales(a1q_scale, s, e), + _chunk_scales(self.fused_experts.a2_scale, s, e), + topk_ids[s:e], + topk_weights[s:e], + ) def slice_output_tensor(chunk_idx: int) -> torch.Tensor: assert fused_out.size(0) % M == 0, ( @@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=c_a1q_scale, - a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module): activation: str = "silu", global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ @@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module): - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. - - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. 
This is only applicable when topk is 1. @@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module): (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, - a1_scale, - a2_scale, topk_weights, topk_ids, global_num_experts, @@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module): dbo_maybe_run_recv_hook() hook, receiver = self.prepare_finalize.prepare_async( a1, - a1_scale, - a2_scale, topk_weights, topk_ids, global_num_experts, @@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b8c1c14317c46..32d12476dd01a 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): repeat_cols = 4 repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) + # TODO(bnell): always pass quant_config.a1_scale? 
a1q, a1q_scale = moe_kernel_quantize_input( - a1, (None if quant_config.per_act_token_quant else a1_scale), + a1, (None if quant_config.per_act_token_quant else + quant_config.a1_scale), quant_dtype=quant_config.quant_dtype, per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) @@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) -> mk.PrepareResultType: hook, receiver = self.prepare_async( a1, - a1_scale, - a2_scale, topk_weights, topk_ids, num_experts, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index bd9f7d4a06b17..588e5de865dd9 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, @@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): a1.mul_(topk_weights.to(a1.dtype)) a1q, a1q_scale = moe_kernel_quantize_input( - a1, a1_scale, quant_config.quant_dtype, + a1, quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 13c3ab4f06dd1..f4972ff5f9cb0 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -7,6 +7,8 @@ from typing import Optional import torch from vllm import envs +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk( def rocm_aiter_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + expert_map: Optional[torch.Tensor] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, +) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG activation_method = (ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU) @@ -333,7 +332,8 @@ def rocm_aiter_fused_experts( expert_mask = None # w8a8 per-channel quantization - if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + if (quant_config.per_act_token_quant and apply_router_weight_on_input + and quant_config.use_fp8_w8a8): # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer # rather than the second FC. 
@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts( w2, topk_weights, topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, + fc1_scale=quant_config.w1_scale, + fc2_scale=quant_config.w2_scale, fc1_smooth_scale=None, fc2_smooth_scale=None, a16=False, @@ -362,14 +362,14 @@ def rocm_aiter_fused_experts( quant_method = QuantMethod.NO.value # w8a8 block-scaled - if block_shape is not None and use_fp8_w8a8: + if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is\ not supported for block scaled moe") - assert w1_scale is not None - assert w2_scale is not None + assert quant_config.w1_scale is not None + assert quant_config.w2_scale is not None quant_method = QuantMethod.BLOCK_128x128.value - elif use_fp8_w8a8: + elif quant_config.use_fp8_w8a8: # Currently only per tensor quantization method is enabled. quant_method = QuantMethod.PER_TENSOR.value @@ -390,10 +390,10 @@ def rocm_aiter_fused_experts( expert_mask=expert_mask, quant_method=quant_method, activation_method=activation_method, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 6cd81d97f0298..b2dbc306a6148 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -7,7 +7,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, + DeepGemmExperts, _valid_deep_gemm, 
_valid_deep_gemm_shape) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used @@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, allow_deep_gemm: bool = False, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - self.triton_expert = TritonExperts( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int4_w4a16=use_int4_w4a16, - use_int8_w8a16=use_int8_w8a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) + super().__init__(quant_config) - self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and + self.triton_expert = TritonExperts(quant_config) + + self.allow_deep_gemm = (allow_deep_gemm + and self.quant_config.use_fp8_w8a8 and self.block_shape == deep_gemm_block_shape()) self.deep_gemm_expert = DeepGemmExperts( - ) if self.allow_deep_gemm else None + self.quant_config) if self.allow_deep_gemm else None @property def activation_formats( @@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - 
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation, global_num_experts, expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, a1q_scale, - a2_scale, workspace13, workspace2, expert_tokens_meta, diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 14dfce4b0e3aa..8e5f6acc9df63 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -5,7 +5,8 @@ from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.utils import next_power_of_2 @@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, - w13_bias, - w2_bias, max_capture_size, ): - super().__init__(moe.quant_config) + super().__init__(quant_config) self.moe = moe self.gemm1_alpha = gemm1_alpha self.gemm1_beta = gemm1_beta self.gemm1_clamp_limit = gemm1_clamp_limit - self.w13_bias = w13_bias - self.w2_bias = w2_bias self.max_capture_size = max_capture_size @property @@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], 
workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( torch.bfloat16).view(torch.int16) - assert w1_scale is not None - assert w2_scale is not None + assert self.w1_scale is not None + assert self.w2_scale is not None kwargs = { "topk_ids": packed_tensor, @@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): "gemm1_weights": w1, "gemm1_weights_scale": - w1_scale, + self.w1_scale, "gemm1_bias": - self.w13_bias, + self.w1_bias, "gemm1_alpha": self.gemm1_alpha, "gemm1_beta": @@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): "gemm2_weights": w2, "gemm2_weights_scale": - w2_scale, + self.w2_scale, "gemm2_bias": self.w2_bias, "output1_scale_scalar": diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 1aeb3f92bc3ea..678942e568d86 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -268,3 +268,7 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def activation_without_mul(activation: str) -> str: + return activation + "_no_mul" diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index bf99f0823b745..060d6e84a944d 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -9,8 +9,10 @@ from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + 
FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase): if hasattr(layer, "w2_bias") and layer.w2_bias is not None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 2245c59af6fea..650dab8df87e3 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union import torch from packaging import version +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): **extra_weight_attrs, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, + quant_config=self.moe_quant_config, ) def _create_weights_4bit( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5470deb768450..85adae32f4cdc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -16,8 +16,11 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, - FusedMoeWeightScaleSupported) + FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config, + int8_w8a16_moe_quant_config, nvfp4_moe_quant_config) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa @@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): return CompressedTensorsWNA16MarlinMoEMethod( quant_config, layer.moe_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) + return CompressedTensorsW4A4MoeMethod(layer.moe_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)): @@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): + def __init__(self, moe: FusedMoEConfig): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 
detect_nvfp4_moe_support) super().__init__(moe) @@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - self.layer = layer def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): (layer.w2_input_global_scale), requires_grad=False) def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if not self.allow_flashinfer: - return super().maybe_make_prepare_finalize(moe) + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin: + return None + elif not self.allow_flashinfer: + return super().maybe_make_prepare_finalize() prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, - ) + self.moe) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None """Return the appropriate GEMM experts implementation.""" experts = select_nvfp4_gemm_impl( - moe, - g1_alphas=self.layer.g1_alphas, - g2_alphas=self.layer.g2_alphas, - a1_gscale=self.layer.w13_input_scale_quant, - a2_gscale=self.layer.w2_input_scale_quant, + self.moe, + self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + return nvfp4_moe_quant_config( + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + 
a2_gscale=layer.w2_input_scale_quant, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + def apply( self, layer: torch.nn.Module, @@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsW4A4MoeMethod` yet.") @@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): indices_type=self.topk_indices_dtype, ) + # + # Note: the order here is important. self.fused_experts can override + # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin. + # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_map=expert_map, workspace=layer.workspace) - # FlashInfer fused experts path - if self.fused_experts is not None: + elif self.fused_experts is not None: assert is_valid_flashinfer_cutlass_fused_moe( x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") @@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + # FlashInfer fused experts path elif self.allow_flashinfer: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 flashinfer_cutlass_moe_fp4) @@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") + assert 
self.moe_quant_config is not None + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=self.moe_quant_config, inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input, ) + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "CompressedTensorsW4A4MoeMethod.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod.") + assert self.moe_quant_config is not None - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_weight_scale, - w2_blockscale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - apply_router_weight_on_input=apply_router_weight_on_input).to( - x.dtype) + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + 
topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + # TODO(bnell): derive these from arguments + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + ).to(x.dtype) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): @@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale - self.fused_experts_func = None - else: - from vllm.model_executor.layers.fused_moe import fused_experts - self.fused_experts_func = fused_experts if self.use_cutlass: device = layer.w13_weight.device @@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): device=device, dtype=torch.int64) + def maybe_make_prepare_finalize( + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin or self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( - self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, - layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute: + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> FusedMoEPermuteExpertsUnpermute: # cutlass path + assert self.moe_quant_config is not None if self.use_cutlass: from vllm.model_executor.layers.fused_moe import ( CutlassBatchedExpertsFp8, CutlassExpertsFp8) @@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) experts = CutlassBatchedExpertsFp8( - moe.num_local_experts, + 
self.moe.num_local_experts, num_dispatchers, - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + self.moe.in_dtype, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, + quant_config=self.moe_quant_config, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) experts = CutlassExpertsFp8( - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + self.moe.in_dtype, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, + quant_config=self.moe_quant_config, ) self.disable_expert_map = (num_dispatchers > 1 @@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): assert not self.rocm_aiter_moe_enabled and not self.use_marlin - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( ) assert max_num_tokens_per_rank is not None + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) return BatchedTritonExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), + quant_config=self.moe_quant_config, ) else: - return TritonExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), - ) + logger.debug("TritonExperts(%s)", self.__class__.__name__) + return TritonExperts(self.moe_quant_config) + + def 
get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + per_act_token = ( + self.input_quant.strategy == QuantizationStrategy.TOKEN) + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_channel_quant, + ) def apply( self, @@ -841,92 +861,19 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): indices_type=self.topk_indices_dtype, ) - # cutlass path - if self.use_cutlass: - per_act_token = ( - self.input_quant.strategy == QuantizationStrategy.TOKEN) - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + per_act_token = ( + self.input_quant.strategy == QuantizationStrategy.TOKEN) + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) - # small-batch fallback on SM100 - if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) - - if self.fused_experts is None: - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) - return cutlass_moe_fp8( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - 
per_act_token=per_act_token, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ab_strides1=self.ab_strides1_c_strides2, - ab_strides2=self.ab_strides2, - c_strides1=self.c_strides1, - c_strides2=self.ab_strides1_c_strides2, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - else: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - - if self.rocm_aiter_moe_enabled: - return self.rocm_aiter_fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy == - QuantizationStrategy.CHANNEL, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - expert_map=expert_map) + # + # Note: the order here is important. self.fused_experts can override + # cutlass fp8 or fused_experts but not marlin or rocm. 
+ # if self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -944,26 +891,95 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map=expert_map, workspace=layer.workspace) - assert self.fused_experts_func is not None + elif self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts) + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + assert self.fused_experts is None + return rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) - return self.fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy == - QuantizationStrategy.CHANNEL, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + elif self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + ) + + # cutlass path + elif self.use_cutlass: + assert self.moe_quant_config is not None + + # small-batch fallback on SM100 + if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: + from 
vllm.model_executor.layers.fused_moe import fused_experts + assert per_act_token == per_channel_quant + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + quant_config=self.moe_quant_config, + ) + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + return cutlass_moe_fp8( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_config=self.moe_quant_config, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) + + else: + from vllm.model_executor.layers.fused_moe import fused_experts + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): @@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return int8_w8a8_moe_quant_config( + 
w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) + def apply( self, layer: torch.nn.Module, @@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): inplace=True, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, - use_int8_w8a8=True, - per_channel_quant=True, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + quant_config=self.moe_quant_config, + ) class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): @@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.workspace = marlin_make_workspace_new(device, 4) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + assert self.num_bits == 4 or self.num_bits == 8 + config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4 + else int8_w8a16_moe_quant_config) + + return config_builder( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_ids=topk_ids, inplace=True, activation=activation, - use_int4_w4a16=self.num_bits == 4, - use_int8_w8a16=self.num_bits == 8, - global_num_experts=global_num_experts, 
apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, self.group_size]) + quant_config=self.moe_quant_config, + ) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index b361fe9bea088..8555e9ff20346 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -8,6 +8,8 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, int8_w8a16_moe_quant_config) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): requires_grad=False) layer.register_parameter("w2_scale", w2_scale) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + w1_zp=None, + w2_zp=None) + def apply( self, layer: torch.nn.Module, @@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_int8_w8a16=True, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale) + quant_config=self.moe_quant_config, + ) @staticmethod def quantizing_weight_loader(layer, weight_loader): diff --git a/vllm/model_executor/layers/quantization/fp8.py 
b/vllm/model_executor/layers/quantization/fp8.py index 254cc2be05ee6..e75094c54743c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -14,9 +14,11 @@ from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( - FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, + FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") - def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe, - layer=self.layer, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w2_weight_scale_inv) + def maybe_make_prepare_finalize( + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if (self.rocm_aiter_moe_enabled or self.use_marlin + or self.flashinfer_moe_backend + == 
FlashinferMoeBackend.TENSORRT_LLM): + return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + prepare_finalize = ( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( @@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet.") + assert self.moe_quant_config is not None + if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): max_num_tokens_per_rank = ( @@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=False, + quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: experts = select_cutlass_fp8_gemm_impl( - moe, - self.layer, + self.moe, + self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.__class__.__name__, self.quant_config.weight_block_size, False) return TritonOrDeepGemmExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, + quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + return 
fp8_w8a8_moe_quant_config( + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) + def apply( self, layer: torch.nn.Module, @@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and self.fused_experts is None): assert activation == 'silu', ( f"Expected 'silu' activation but got {activation}") assert scoring_func == 'sigmoid', ( f"Expected 'sigmoid' scoring func but got {scoring_func}") if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 assert (renormalize and use_grouped_topk and custom_routing_function is None) @@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): logical_replica_count=logical_replica_count, ) + # + # Note: the order of checks is important since self.fused_experts + # can override fused_experts or cutlass but not rocm or marlin. 
+ # if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts) + assert self.fused_experts is None return rocm_aiter_fused_experts( x, layer.w13_weight, @@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - use_fp8_w8a8=True, apply_router_weight_on_input=apply_router_weight_on_input, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - expert_map=expert_map) + expert_map=expert_map, + quant_config=self.moe_quant_config) elif self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -1105,40 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): global_num_experts=global_num_experts, expert_map=expert_map, workspace=layer.workspace) - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert self.block_quant is None - assert (not renormalize and custom_routing_function is not None) - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - assert scoring_func == 'sigmoid', ( - f"Expected 'sigmoid' scoring func but got {scoring_func}") - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - 
global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - common_kwargs = dict( + elif self.fused_experts: + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1149,26 +1133,43 @@ class Fp8MoEMethod(FusedMoEMethodBase): global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert self.block_quant is None + assert (not renormalize and custom_routing_function is not None) + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") - if self.fused_experts is not None: - return self.fused_experts(**common_kwargs) - else: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( - **common_kwargs, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm), - ) + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + 
global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm)) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 01af1ccd9ae06..a631dfdab6544 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_qweight_type, extra_weight_attrs) layer.register_parameter("w2_qweight_type", w2_qweight_type) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 76de3a59c8ca1..e06b974255f01 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -9,8 +9,10 @@ import torch import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, 
FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) @@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): if hasattr(layer, "w2_bias") and layer.w2_bias is not None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 5f9d4814274c8..c83b0b47a4b7e 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter from vllm._ipex_ops import ipex_ops as ops from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): use_prepack=True, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 60a79e53e8141..7eac40825ac33 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -11,7 +11,9 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, 
scaled_fp4_quant from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.fused_moe.layer import ( @@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( @@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if self.fused_experts is not None or \ - self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe, - layer=self.layer, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize + self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + # TRT LLM not supported with all2all yet. 
+ if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + prepare_finalize = ( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None experts = select_cutlass_fp8_gemm_impl( - moe, - self.layer, + self.moe, + self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + return None + + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=False, + ) + def apply( self, layer: torch.nn.Module, @@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): "EPLB not supported for `ModelOptFp8MoEMethod` yet.") if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert self.fused_experts is None assert activation == 'silu', ( f"Expected 'silu' activation but got {activation}") assert not renormalize @@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): indices_type=self.topk_indices_dtype, ) - if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + # + # Note: the order here is important. self.fused_experts can override + # cutlass or fused_experts. 
+ # + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize assert activation == 'silu', ( f"Expected 'silu' activation but got {activation}") - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - per_channel_quant=False, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) + assert self.moe_quant_config is not None + 
+ return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + quant_config=self.moe_quant_config, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) class ModelOptNvFp4Config(QuantizationConfig): @@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): " for ModelOptNvFp4FusedMoE.") def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if (self.allow_flashinfer and self.flashinfer_moe_backend - == FlashinferMoeBackend.CUTLASS): + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if (self.use_marlin + or (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM)): + return None + elif (self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + # For now, fp4 moe only works with the flashinfer dispatcher. 
prepare_finalize = ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, - )) + build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize - - return super().maybe_make_prepare_finalize(moe) + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None experts = select_nvfp4_gemm_impl( - moe, - g1_alphas=self.layer.g1_alphas, - g2_alphas=self.layer.g2_alphas, - a1_gscale=self.layer.w13_input_scale_quant, - a2_gscale=self.layer.w2_input_scale_quant, + self.moe, + self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) @@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if (self.use_marlin or self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM): + return None + + return nvfp4_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + ) + def apply( self, layer: torch.nn.Module, @@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") assert activation == "silu", "Only SiLU activation is supported." 
- if self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM): import flashinfer from vllm.model_executor.models.llama4 import Llama4MoE + assert self.fused_experts is None + a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( @@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) + # + # Note: the order here is important. self.fused_experts can override + # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or + # trtllm. + # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_map=expert_map, workspace=layer.workspace) - if self.fused_experts is not None: + elif self.fused_experts is not None: assert self.allow_flashinfer and \ self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS @@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") - out = self.fused_experts( + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) elif (self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 flashinfer_cutlass_moe_fp4) + assert self.moe_quant_config is not None - out = 
flashinfer_cutlass_moe_fp4( + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - inplace=False, # TODO(shuw): fix later, now output is high prec + quant_config=self.moe_quant_config, + inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp4) - out = cutlass_moe_fp4( + assert self.moe_quant_config is not None + return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_weight_scale, - w2_blockscale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=self.moe_quant_config, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + # TODO: derive from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - - return out + ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index c25b3dd6080dc..145b614237fb3 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.config 
import ( + FusedMoEQuantConfig, int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, @@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase): layer.register_parameter(key, param) set_weight_attrs(param, extra_weight_attrs) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + assert weight_bits == 4 or weight_bits == 8 + config_builder = (int4_w4a16_moe_quant_config + if weight_bits == 4 else int8_w8a16_moe_quant_config) + + return config_builder( + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) - weight_bits = self.quant_config.weight_bits - has_zp = self.quant_config.has_zp - return fused_experts( x, layer.w13_qweight, @@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales, - w1_zp=layer.w13_qzeros if has_zp else None, - w2_zp=layer.w2_qzeros if has_zp else None, - block_shape=[0, layer.group_size]) + quant_config=self.moe_quant_config, + ) @staticmethod def get_weight_loader(layer, weight_loader): diff --git 
a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index f935bdd84124a..28c1e60ccd08a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -12,6 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return tile_tokens_dim + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + return None + + if self.mxfp4_backend == Mxfp4Backend.TRITON: + w1_scale = layer.w13_precision_config + w2_scale = layer.w2_precision_config + else: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: if (prepare_finalize.activation_format == @@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): "gemm1_alpha": layer.gemm1_alpha, "gemm1_beta": layer.gemm1_beta, "gemm1_clamp_limit": layer.gemm1_clamp_limit, - "w13_bias": layer.w13_bias, - "w2_bias": layer.w2_bias, + # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } - return TrtLlmGenExperts(moe, **kwargs) + assert self.moe_quant_config is not None + return TrtLlmGenExperts(self.moe, 
self.moe_quant_config, + **kwargs) else: # Use matmul_ogs from triton_kernels here! raise NotImplementedError( @@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): renormalize=renormalize, global_num_experts=global_num_experts, expert_map=expert_map, - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_precision=self.w13_precision_config, - w2_precision=self.w2_precision_config, + quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, ) else: diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index bc8ae980429a3..d2d990e46bcf3 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -11,6 +11,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + mxfp4_w4a4_moe_quant_config) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + 
a2_scale=layer.w2_input_scale, + per_act_token_quant=self.weight_qscheme == "per_channel", + ) + def apply( self, layer: torch.nn.Module, @@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): topk_ids=topk_ids, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_qscheme == "per_channel", - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + quant_config=self.moe_quant_config, expert_map=expert_map) if self.use_marlin: assert activation == "silu", ( @@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): inplace=True, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_qscheme == "per_channel", global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + quant_config=self.moe_quant_config) class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): @@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return mxfp4_w4a4_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + block_shape=None, + ) + def apply( self, layer: torch.nn.Module, @@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_mxfp4_w4a4=True, + activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - 
w2_scale=layer.w2_weight_scale, - a1_scale=None, - a2_scale=None, - block_shape=None, - activation=activation, + quant_config=self.moe_quant_config, ) return out diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 0d5fa05652b80..ed90e2e26460e 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase): fix_weights(layer, "w13_weight", weight_bits == 4) fix_weights(layer, "w2_weight", weight_bits == 4) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + weight_bits = self.quant_config.weight_bits + group_size = self.quant_config.group_size + assert weight_bits == 4 or weight_bits == 8 + config_builder = (int4_w4a16_moe_quant_config + if weight_bits == 4 else int8_w8a16_moe_quant_config) + return config_builder( + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) - weight_bits = self.quant_config.weight_bits - group_size = self.quant_config.group_size - - ret = fused_experts( + return fused_experts( x, layer.w13_weight, layer.w2_weight, @@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase): 
topk_ids=topk_ids, inplace=True, activation=activation, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - block_shape=[0, group_size]) - - return ret + quant_config=self.moe_quant_config, + ) def rtn_quantize(tensor: torch.Tensor, num_bits: int, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index f5d7c57fe2a87..fabf855b36e68 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -7,7 +7,8 @@ import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 @@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, def build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig, - a1_gscale: torch.Tensor, -) -> mk.FusedMoEPrepareAndFinalize: + moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 - return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) + return FlashInferCutlassMoEPrepareAndFinalize(use_dp) def select_nvfp4_gemm_impl( moe: FusedMoEConfig, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + moe_quant_config: 
FusedMoEQuantConfig, allow_flashinfer: bool, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" if allow_flashinfer: return FlashInferExperts( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, out_dtype=moe.in_dtype, - quant_dtype="nvfp4", + quant_config=moe_quant_config, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 9889808f0760f..aa66a42c588a7 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -8,7 +8,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 @@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8( apply_router_weight_on_input: bool, ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType + + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 assert layer.output1_scales_scalar is not None, ( "Expected output1_scales_scalar to be initialized") assert layer.output1_scales_scalar is not None, ( @@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: Optional[FusedMoEConfig], - layer: torch.nn.Module, -) -> mk.FusedMoEPrepareAndFinalize: + 
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - return FlashInferCutlassMoEPrepareAndFinalize( - use_dp, a1_gscale=layer.w13_input_scale) + return FlashInferCutlassMoEPrepareAndFinalize(use_dp) def select_cutlass_fp8_gemm_impl( moe: Optional[FusedMoEConfig], - layer: torch.nn.Module, + quant_config: FusedMoEQuantConfig, out_dtype: Optional[torch.dtype] = None, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for fused-MoE layers""" - from vllm.model_executor.models.llama4 import Llama4MoE - assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ - "FusedMoE flashinfer kernels are only supported for Llama4" - if moe is not None: return FlashInferExperts( - g1_alphas=layer.output1_scales_gate_scalar, - g2_alphas=layer.output2_scales_scalar, - a1_gscale=layer.w13_input_scale, - a2_gscale=layer.w2_input_scale_inv, out_dtype=moe.in_dtype, - quant_dtype=torch.float8_e4m3fn, + quant_config=quant_config, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, @@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl( assert out_dtype is not None, ( "If moe config is None, out_dtype must be passed") return FlashInferExperts( - g1_alphas=layer.output1_scales_gate_scalar, - g2_alphas=layer.output2_scales_scalar, - a1_gscale=layer.w13_input_scale, - a2_gscale=layer.w2_input_scale_inv, out_dtype=out_dtype, - quant_dtype=torch.float8_e4m3fn, + quant_config=quant_config, ) @@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8( expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: + quant_config = layer.quant_method.get_fused_moe_quant_config(layer) + assert quant_config is not None + fused_experts = mk.FusedMoEModularKernel( - 
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, - layer=layer), + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None), select_cutlass_fp8_gemm_impl(moe=None, - layer=layer, + quant_config=quant_config, out_dtype=hidden_states.dtype)) return fused_experts( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index e3e9635132d68..bbe0c6f6d38ec 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -411,6 +411,7 @@ def per_token_group_quant_fp8( x_s = torch.empty(shape, device=x.device, dtype=torch.float32) # prefer CUDA kernel if available + # TODO(bnell): this causes some fp8 moe test to fail. if current_platform.is_cuda() and x.is_contiguous(): torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index b758cbf28d893..bfc1408ddf880 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import (get_act_and_mul_fn, get_act_fn) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, torch_vllm_outplace_fused_experts) +from vllm.model_executor.layers.fused_moe import (activation_without_mul, + fused_topk) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -230,7 +230,7 @@ class NomicMoE(nn.Module): self.hidden_size = hidden_size self.total_intermediate_size = intermediate_size self.intermediate_size = divide(intermediate_size, self.tp_size) - self.hidden_act = hidden_act + self.hidden_act = activation_without_mul(hidden_act) if params_dtype 
is None: params_dtype = torch.get_default_dtype() @@ -297,14 +297,14 @@ class NomicMoE(nn.Module): router_logits, self.top_k, renormalize=False) - final_hidden_states = torch_vllm_outplace_fused_experts( + + final_hidden_states = torch.ops.vllm.outplace_fused_experts( hidden_states=hidden_states, w1=self.w1, w2=self.w2, topk_weights=topk_weights, topk_ids=topk_ids, activation=self.hidden_act, - is_act_and_mul=False, ) if self.tp_size > 1: diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 4395b11b7d0f0..59c9921881497 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module): shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) + + topk_weights, topk_ids, _ = fused_topk( + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob) + + final_hidden_states = fused_experts(hidden_states, + self.w1, + self.w2, + topk_weights, + topk_ids, + inplace=True) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 
c7be7f76dba15..240c23ea2b25d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True) + + topk_weights, topk_ids, _ = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=True) + + final_hidden_states = fused_experts(hidden_states, + self.ws, + self.w2s, + topk_weights, + topk_ids, + inplace=True) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index f66e8b0b454bf..029309c49efd4 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -702,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() + return self.model.get_expert_mapping() \ No newline at end of file diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index a25ef86a989db..a636a714145cf 100644 --- 
a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: - if not (isinstance(module, FusedMoE) - and module.moe_config.quant_dtype == torch.float8_e4m3fn - and module.moe_config.block_shape == deep_gemm_block_shape()): + if not isinstance(module, FusedMoE): + return False + + moe_quant_config = module.quant_method.get_fused_moe_quant_config(module) + + if (moe_quant_config is None + or moe_quant_config.quant_dtype != torch.float8_e4m3fn + or moe_quant_config.block_shape != deep_gemm_block_shape()): return False if not isinstance(module.quant_method.fused_experts,