From 8ad7285ea28ad3bcc898fa99812120bcda8ea7b4 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:46:00 -0400 Subject: [PATCH] [Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035) Signed-off-by: Bill Nell Co-authored-by: Michael Goin --- .buildkite/test-pipeline.yaml | 1 + docs/design/fused_moe_modular_kernel.md | 10 +- examples/offline_inference/data_parallel.py | 23 +- .../moe/modular_kernel_tools/common.py | 532 +++++++++--------- .../moe/modular_kernel_tools/mk_objects.py | 461 ++++++++++++++- .../profile_modular_kernel.py | 4 +- .../kernels/moe/modular_kernel_tools/utils.py | 117 ---- tests/kernels/moe/test_batched_moe.py | 4 +- tests/kernels/moe/test_block_fp8.py | 31 +- tests/kernels/moe/test_block_int8.py | 15 +- .../kernels/moe/test_cutlass_grouped_gemm.py | 17 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 6 +- tests/kernels/moe/test_deepgemm.py | 6 +- tests/kernels/moe/test_flashinfer_moe.py | 147 +++++ .../moe/test_modular_kernel_combinations.py | 129 +++-- tests/kernels/moe/test_nvfp4_moe.py | 60 +- tests/kernels/moe/test_pplx_cutlass_moe.py | 11 +- tests/kernels/moe/test_pplx_moe.py | 4 +- tests/kernels/moe/utils.py | 75 ++- .../base_device_communicator.py | 7 +- .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/batched_deep_gemm_moe.py | 36 +- .../batched_triton_or_deep_gemm_moe.py | 38 +- .../model_executor/layers/fused_moe/config.py | 11 +- .../layers/fused_moe/cutlass_moe.py | 326 ++++++----- .../layers/fused_moe/deep_gemm_moe.py | 3 +- .../fused_moe/deepep_ht_prepare_finalize.py | 30 +- .../fused_moe/deepep_ll_prepare_finalize.py | 32 +- .../fused_moe/flashinfer_cutlass_moe.py | 59 +- .../flashinfer_cutlass_prepare_finalize.py | 52 +- .../layers/fused_moe/fused_batched_moe.py | 98 ++-- .../layers/fused_moe/fused_moe.py | 7 +- .../fused_moe/gpt_oss_triton_kernels_moe.py | 15 +- vllm/model_executor/layers/fused_moe/layer.py | 93 +-- .../layers/fused_moe/modular_kernel.py | 117 ++-- .../layers/fused_moe/pplx_prepare_finalize.py | 33 +- .../layers/fused_moe/prepare_finalize.py | 43 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 37 +- vllm/model_executor/layers/fused_moe/utils.py | 18 +- .../layers/quantization/auto_round.py | 4 +- .../model_executor/layers/quantization/awq.py | 2 +- .../layers/quantization/awq_marlin.py | 18 +- .../layers/quantization/bitsandbytes.py | 12 +- .../compressed_tensors_moe.py | 168 ++++-- .../layers/quantization/experts_int8.py | 17 +- .../model_executor/layers/quantization/fp8.py | 43 +- .../layers/quantization/gguf.py | 15 +- .../layers/quantization/gptq_marlin.py | 14 +- .../layers/quantization/modelopt.py | 99 ++-- .../layers/quantization/moe_wna16.py | 16 +- .../layers/quantization/mxfp4.py | 2 +- .../layers/quantization/quark/quark_moe.py | 39 +- .../model_executor/layers/quantization/rtn.py | 13 +- .../quantization/utils/flashinfer_fp4_moe.py | 129 +---- 54 files changed, 2010 insertions(+), 1293 deletions(-) delete mode 100644 tests/kernels/moe/modular_kernel_tools/utils.py create mode 100644 tests/kernels/moe/test_flashinfer_moe.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 04d7cdc3d885..87296a08e207 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -399,6 +399,7 @@ steps: - label: Kernels MoE Test %N mirror_hardwares: [amdexperimental] source_file_dependencies: + - csrc/quantization/cutlass_w8a8/moe/ - csrc/moe/ - 
tests/kernels/moe - vllm/model_executor/layers/fused_moe/
diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md
index 3ef1232051b0..4b917ab408ee 100644
--- a/docs/design/fused_moe_modular_kernel.md
+++ b/docs/design/fused_moe_modular_kernel.md
@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
 
 ### FusedMoEModularKernel Initialization
 
-`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
+The `FusedMoEMethodBase` class has 3 methods that are collectively responsible for creating the `FusedMoEModularKernel` object. They are:
 
+* maybe_make_prepare_finalize,
 * select_gemm_impl, and
 * init_prepare_finalize
 
+#### maybe_make_prepare_finalize
+
+The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate, based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for other scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
+Please refer to the implementations in:
+
+* `ModelOptNvFp4FusedMoE`
+
 #### select_gemm_impl
 
 The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index dbf8ed58cc47..dd7559451c4c 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -70,12 +70,27 @@ def parse_args():
         default=64,
         help=("Maximum number of sequences to be processed in a single iteration."),
     )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        help=("Maximum model context length (maximum number of tokens per sequence)."),
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=300,
+        help=("Number of seconds before an unresponsive process is killed."),
+    )
     parser.add_argument(
         "--gpu-memory-utilization",
         type=float,
         default=0.8,
         help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+    )
     return parser.parse_args()
 
 
@@ -90,7 +105,9 @@ def main(
     enforce_eager,
     trust_remote_code,
     max_num_seqs,
+    max_model_len,
     gpu_memory_utilization,
+    quantization,
 ):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
@@ -142,7 +159,9 @@ def main(
         enable_expert_parallel=True,
         trust_remote_code=trust_remote_code,
         max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
+        quantization=quantization,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -198,14 +217,16 @@ if __name__ == "__main__": args.enforce_eager, args.trust_remote_code, args.max_num_seqs, + args.max_model_len, args.gpu_memory_utilization, + args.quantization, ), ) proc.start() procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=args.timeout) if proc.exitcode is None: print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") proc.kill() diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index fd99e8dc5c98..a10666b6ec9a 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -7,41 +7,22 @@ import torch import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8 +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from tests.kernels.utils import torch_experts from vllm.config import VllmConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size -# Fused experts and PrepareFinalize imports -from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, - TritonExperts) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from .mk_objects import (expert_info, make_fused_experts, + make_prepare_finalize, prepare_finalize_info) from .parallel_utils import ProcessGroupInfo -from .utils import (make_block_quant_fp8_weights, make_non_quant_weights, - make_quant_fp8_weights, per_token_cast_to_fp8) - -if has_pplx(): - from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) -if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: @@ -69,24 +50,31 @@ class Config: torch_trace_dir_path: Optional[str] = None + def __post_init__(self): + if self.quant_config is None: + self.quant_config = FusedMoEQuantConfig() + def describe(self) -> str: s = "" - s += "== Config: \n" - s += f" world_size={self.world_size} \n" - s += f" PF={self.prepare_finalize_type.__name__} \n" - s += f" FE={self.fused_experts_type.__name__} \n" - s += f" topk={self.topks} \n" - s += f" dtype={self.dtype} \n" - s += f" 
fused_moe_chunk_size={self.fused_moe_chunk_size} \n" - s += " Quant: \n" - s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " + s += "== Config:\n" + s += f" world_size={self.world_size}\n" + s += f" PF={self.prepare_finalize_type.__name__}\n" + s += f" FE={self.fused_experts_type.__name__}\n" + s += f" E={self.E}\n" + s += f" Ms={self.Ms}\n" + s += f" N={self.N}\n" + s += f" K={self.K}\n" + s += f" topk={self.topks}\n" + s += f" dtype={self.dtype}\n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n" + s += " Quant:\n" if self.quant_config is not None: - s += f" q_dtype={self.quant_dtype} \n" - s += f" q_block_shape={self.quant_block_shape} \n" - s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" - s += f" q_per_act_token={self.is_per_act_token_quant} \n" + s += f" q_dtype={self.quant_dtype}\n" + s += f" q_block_shape={self.quant_block_shape}\n" + s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n" + s += f" q_per_act_token={self.is_per_act_token_quant}\n" else: - s += " quant=None \n" + s += " quant=None\n" return s @property @@ -95,34 +83,28 @@ class Config: return self.Ms @property - def quant_dtype(self) -> Optional[torch.dtype]: - if self.quant_config is None: - return None + def quant_dtype(self) -> Union[torch.dtype, str, None]: + assert self.quant_config is not None return self.quant_config.quant_dtype @property def is_per_act_token_quant(self) -> bool: - if self.quant_config is None: - return False + assert self.quant_config is not None return self.quant_config.per_act_token_quant @property def is_per_tensor_act_quant(self) -> bool: - if self.quant_config is None: - return False return (not self.is_per_act_token_quant and self.quant_block_shape is None) @property def is_per_out_ch_quant(self) -> bool: - if self.quant_config is None: - return False + assert self.quant_config is not None return self.quant_config.per_out_ch_quant @property def quant_block_shape(self) -> Optional[list[int]]: - if self.quant_config is None: - return None + assert self.quant_config is not None return self.quant_config.block_shape @property @@ -130,36 +112,30 @@ class Config: assert isinstance(self.topks, int) return self.topks - @property - def topk_ids_dtype(self) -> Optional[torch.dtype]: - topk_ids_dtype = None - if self.prepare_finalize_type == PplxPrepareAndFinalize: - topk_ids_dtype = torch.uint32 - elif self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ]: - topk_ids_dtype = torch.int64 - return topk_ids_dtype - @property def num_local_experts(self) -> int: return self.E // self.world_size def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: """ - make env data for vllm launch. + make env data for vllm launch. 
""" vllm_config = VllmConfig() vllm_config.parallel_config.data_parallel_size = self.world_size vllm_config.parallel_config.enable_expert_parallel = True env_dict = { - "VLLM_ALL2ALL_BACKEND": self.all2all_backend(), "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), } + + backend = self.all2all_backend() + if backend is not None: + env_dict.update({"VLLM_ALL2ALL_BACKEND": backend}) + if self.fused_moe_chunk_size is not None: env_dict.update( {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + return vllm_config, env_dict def is_fp8_block_quantized(self): @@ -167,85 +143,59 @@ class Config: and self.quant_block_shape is not None) def is_batched_prepare_finalize(self): - return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return (mk.FusedMoEActivationFormat.BatchedExperts == + info.activation_format) def is_batched_fused_experts(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, - NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts - ] + info = expert_info(self.fused_experts_type) + return (mk.FusedMoEActivationFormat.BatchedExperts == + info.activation_format) def is_standard_fused_experts(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts - ] + info = expert_info(self.fused_experts_type) + return mk.FusedMoEActivationFormat.Standard == info.activation_format - def is_fe_16bit_supported(self): - return self.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - NaiveBatchedExperts, TritonExperts - ] + def fe_supported_types(self): + info = expert_info(self.fused_experts_type) + return info.supported_dtypes - def is_fe_fp8_supported(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - BatchedTritonExperts, - BatchedTritonOrDeepGemmExperts, - CutlassExpertsFp8, - DeepGemmExperts, - TritonExperts, - TritonOrDeepGemmExperts, - NaiveBatchedExperts, - ] + def pf_supported_types(self): + info = prepare_finalize_info(self.prepare_finalize_type) + return info.supported_dtypes - def is_fe_block_fp8_supported(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - BatchedTritonOrDeepGemmExperts, - DeepGemmExperts, - TritonExperts, - TritonOrDeepGemmExperts, - BatchedTritonExperts, - NaiveBatchedExperts, - ] + def is_block_quant_supported(self): + info = expert_info(self.fused_experts_type) + return info.blocked_quantization_support def is_fe_supports_chunking(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts - ] + info = expert_info(self.fused_experts_type) + return info.supports_chunking + + def supports_expert_map(self): + info = expert_info(self.fused_experts_type) + return info.supports_expert_map + + def supports_apply_weight_on_input(self): + info = prepare_finalize_info(self.prepare_finalize_type) + return info.supports_apply_weight_on_input def needs_deep_gemm(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - DeepGemmExperts, - ] + info = expert_info(self.fused_experts_type) + return info.needs_deep_gemm def needs_pplx(self): - return self.prepare_finalize_type in [PplxPrepareAndFinalize] + info = prepare_finalize_info(self.prepare_finalize_type) + return info.backend == "pplx" def needs_deep_ep(self): - return self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, 
DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return (info.backend == "deepep_high_throughput" + or info.backend == "deepep_low_latency") def all2all_backend(self): - if self.needs_pplx(): - return "pplx" - if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize: - return "deepep_high_throughput" - if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize: - return "deepep_low_latency" - return "naive" - - def needs_all2all(self): - return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize, - DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return info.backend def is_valid(self): # Check prepare-finalize and fused-experts compatibility @@ -267,28 +217,28 @@ class Config: # invalid quant config return False - # check bf16 / fp16 support - is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) - if is_16bit and not self.is_fe_16bit_supported(): - return False + # check type support + if self.quant_dtype is None: + if (self.dtype not in self.pf_supported_types() + or self.dtype not in self.fe_supported_types()): + return False + else: + if (self.quant_dtype not in self.pf_supported_types() + or self.quant_dtype not in self.fe_supported_types()): + return False - # Check fp8 support - is_fp8 = self.quant_dtype == torch.float8_e4m3fn - if is_fp8 and not self.is_fe_fp8_supported(): - return False - - # Check fp8 block quanization support + # Check block quanization support is_block_quatized = self.quant_block_shape is not None - if is_block_quatized and not is_fp8: + if is_block_quatized and self.quant_dtype is None: return False - if is_block_quatized and not self.is_fe_block_fp8_supported(): + if is_block_quatized and not self.is_block_quant_supported(): return False # deep_gemm only works with block-quantized if self.needs_deep_gemm() and not is_block_quatized: return False - # Check dependencies + # Check dependencies (turn into asserts?) if self.needs_deep_ep() and not has_deep_ep(): return False if self.needs_deep_gemm() and not has_deep_gemm(): @@ -305,6 +255,8 @@ class WeightTensors: w2: torch.Tensor w1_scale: Optional[torch.Tensor] w2_scale: Optional[torch.Tensor] + w1_gs: Optional[torch.Tensor] = None + w2_gs: Optional[torch.Tensor] = None def describe(self): s = "" @@ -313,13 +265,20 @@ class WeightTensors: s += f' - {_describe_tensor(self.w2, "w2")} \n' s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' + s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n' + s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n' return s + def is_quantized(self) -> bool: + # or w1_scale is not None? 
+ return (self.w1.dtype == torch.float8_e4m3fn + or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) + def to_current_device(self): self.w1 = self.w1.to(device=torch.cuda.current_device()) self.w2 = self.w2.to(device=torch.cuda.current_device()) - is_quantized = self.w1.dtype == torch.float8_e4m3fn - if is_quantized: + + if self.is_quantized(): assert self.w1_scale is not None assert self.w2_scale is not None self.w1_scale = self.w1_scale.to( @@ -327,56 +286,51 @@ class WeightTensors: self.w2_scale = self.w2_scale.to( device=torch.cuda.current_device()) + if self.w1_gs is not None: + assert self.w2_gs is not None + self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) + self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) + def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": s = rank * num_local_experts e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - is_quantized = self.w1.dtype == torch.float8_e4m3fn + w1_scale, w2_scale = (None, None) - if is_quantized: + if self.is_quantized(): assert self.w1_scale is not None assert self.w2_scale is not None w1_scale = self.w1_scale[s:e, :, :] w2_scale = self.w2_scale[s:e, :, :] - return WeightTensors(w1, w2, w1_scale, w2_scale) + + w1_gs = self.w1_gs + w2_gs = self.w2_gs + if w1_gs is not None: + assert w2_gs is not None + w1_gs = w1_gs[s:e] + w2_gs = w2_gs[s:e] + + return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) @staticmethod def make(config: Config) -> "WeightTensors": - - if config.quant_dtype is None: - # just make normal dtype weights - w1, w2 = make_non_quant_weights(e=config.E, - n=config.N, - k=config.K, - dtype=config.dtype) - return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None) - - assert config.quant_dtype == torch.float8_e4m3fn - if not config.is_fp8_block_quantized(): - w1, w2, w1_scale, w2_scale = make_quant_fp8_weights( - e=config.E, - n=config.N, - k=config.K, - per_out_channel_quant=config.is_per_out_ch_quant, - ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale) - - assert config.quant_block_shape is not None - w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights( e=config.E, n=config.N, k=config.K, - block_size=config.quant_block_shape, + in_dtype=config.dtype, + quant_dtype=config.quant_dtype, + block_shape=config.quant_block_shape, + per_act_token_quant=config.is_per_out_ch_quant, ) return WeightTensors(w1=w1, w2=w2, w1_scale=w1_scale, - w2_scale=w2_scale) + w2_scale=w2_scale, + w1_gs=w1_gs, + w2_gs=w2_gs) @dataclass @@ -449,7 +403,6 @@ class RankTensors: dtype=dtype) topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) - topk_ids = topk_ids.to(config.topk_ids_dtype) # distribute topk_ids evenly for mi in range(m): @@ -457,7 +410,7 @@ class RankTensors: topk_ids = topk_ids.to(device=torch.cuda.current_device()) expert_map = None - if config.world_size > 1: + if config.world_size > 1 and config.supports_expert_map(): expert_map = torch.full((global_num_experts, ), fill_value=-1, dtype=torch.int32) @@ -480,92 +433,100 @@ class RankTensors: def reference_moe_impl(config: Config, weights: WeightTensors, rank_tensors: RankTensors) -> torch.Tensor: - return torch_experts(a=rank_tensors.hidden_states, - w1=weights.w1, - w2=weights.w2, + if config.quant_dtype == "nvfp4": + quant_blocksize = 16 + dtype = config.dtype + + w1_q = weights.w1 + w1_blockscale = weights.w1_scale + w1_gs = weights.w1_gs + + 
w2_q = weights.w2 + w2_blockscale = weights.w2_scale + w2_gs = weights.w2_gs + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax( + rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32) + + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + + assert w1_blockscale.shape[1] % 128 == 0 + assert w1_blockscale.shape[2] % 4 == 0 + assert w2_blockscale.shape[1] % 128 == 0 + assert w2_blockscale.shape[2] % 4 == 0 + + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( + rank_tensors.hidden_states, a_global_scale) + + a = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=dtype, + device=a_fp4.device, + block_size=quant_blocksize) + + e = w1_q.shape[0] + n = w1_q.shape[1] // 2 + k = w2_q.shape[1] + + w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype) + w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize) + a_scale = None + w1_scale = None + w2_scale = None + quant_dtype = None + per_act_token_quant = False + block_shape = None + else: + a = rank_tensors.hidden_states + a_scale = rank_tensors.hidden_states_scale + w1 = weights.w1 + w1_scale = weights.w1_scale + w2 = weights.w2 + w2_scale = weights.w2_scale + quant_dtype = config.quant_dtype + per_act_token_quant = config.is_per_act_token_quant + block_shape = config.quant_block_shape + + return torch_experts(a=a, + w1=w1, + w2=w2, topk_weight=rank_tensors.topk_weights, topk_ids=rank_tensors.topk_ids, global_num_experts=config.E, expert_map=None, - w1_scale=weights.w1_scale, - w2_scale=weights.w2_scale, - a1_scale=rank_tensors.hidden_states_scale, - quant_dtype=config.quant_dtype, - per_act_token_quant=config.is_per_act_token_quant, - block_shape=config.quant_block_shape, - apply_router_weights_on_input=config.topk == 1) + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + apply_router_weights_on_input=config.topk == 1 + and config.supports_apply_weight_on_input()) -def make_fused_experts( - config: Config, moe: FusedMoEConfig, - num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: - - use_fp8 = config.quant_dtype == torch.float8_e4m3fn - batch_kwargs = { - "max_num_tokens": moe.max_num_tokens, - "num_dispatchers": num_dispatchers, - } - quant_kwargs = { - "use_fp8_w8a8": use_fp8, - "use_int8_w8a8": False, - "use_int8_w8a16": False, - "use_int4_w4a16": False, - "block_shape": config.quant_block_shape, - "per_act_token_quant": config.is_per_act_token_quant, - } - deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} - - if config.fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | { - "block_shape": config.quant_block_shape, - "per_act_token_quant": config.is_per_act_token_quant, - } - print(f"Making BatchedDeepGemmExperts {kwargs} ...") - experts = BatchedDeepGemmExperts(**kwargs) - elif config.fused_experts_type == BatchedTritonExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making BatchedTritonExperts {kwargs} ...") - experts = BatchedTritonExperts(**kwargs) - elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts: - kwargs = 
batch_kwargs | quant_kwargs | deepgemm_kwargs - print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") - experts = BatchedTritonOrDeepGemmExperts(**kwargs) - elif config.fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts () ...") - experts = DeepGemmExperts() - elif config.fused_experts_type == TritonExperts: - kwargs = quant_kwargs - print(f"Making TritonExperts {kwargs} ...") - experts = TritonExperts(**kwargs) - elif config.fused_experts_type == TritonOrDeepGemmExperts: - kwargs = quant_kwargs | deepgemm_kwargs - print(f"Making TritonOrDeepGemmExperts {kwargs} ...") - experts = TritonOrDeepGemmExperts(**kwargs) - elif config.fused_experts_type == NaiveBatchedExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making NaiveBatchedExperts {kwargs} ...") - experts = NaiveBatchedExperts(**kwargs) - elif config.fused_experts_type == CutlassExpertsFp8: - use_batched_format = config.is_batched_prepare_finalize() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) - kwargs = { - "max_experts_per_worker": num_experts, - "out_dtype": moe.in_dtype, - "per_act_token_quant": config.is_per_act_token_quant, - "per_out_ch_quant": config.is_per_out_ch_quant, - "block_shape": config.quant_block_shape, - "num_dispatchers": num_dispatchers, - "use_batched_format": use_batched_format - } - print(f"Making CutlassExpertsFp8 {kwargs} ...") - experts = CutlassExpertsFp8(**kwargs) - - return experts - - -def make_modular_kernel(config: Config, - vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: +def make_modular_kernel( + config: Config, + vllm_config: VllmConfig, + weights: WeightTensors, +) -> mk.FusedMoEModularKernel: def next_power_of_2(x): import math @@ -579,6 +540,7 @@ def make_modular_kernel(config: Config, dp_size_=get_dp_group().world_size, vllm_parallel_config=vllm_config.parallel_config, ) + moe = FusedMoEConfig( num_experts=config.E, experts_per_token=config.topk, @@ -591,15 +553,16 @@ def make_modular_kernel(config: Config, ) # make modular kernel - prepare_finalize = None - if config.needs_all2all(): - prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe) - assert prepare_finalize is not None - else: - prepare_finalize = MoEPrepareAndFinalizeNoEP() + prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, + config.all2all_backend(), moe) - fused_experts = make_fused_experts(config, moe, - prepare_finalize.num_dispatchers()) + fused_experts = make_fused_experts( + config.fused_experts_type, + moe, + prepare_finalize.num_dispatchers(), + weights.w1_gs, + weights.w2_gs, + ) modular_kernel = mk.FusedMoEModularKernel( prepare_finalize=prepare_finalize, fused_experts=fused_experts) @@ -620,22 +583,45 @@ def run_modular_kernel( # weights for rank rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) - mk = make_modular_kernel(config, vllm_config) + mk = make_modular_kernel(config, vllm_config, weights) mk_kwargs = { - "hidden_states": rank_tensors.hidden_states.clone( + "hidden_states": + rank_tensors.hidden_states.clone( ), # impls might update the tensor in place - "w1": rank_weights.w1, - "w2": rank_weights.w2, - "topk_weights": rank_tensors.topk_weights, - "topk_ids": rank_tensors.topk_ids, - "expert_map": rank_tensors.expert_map, - "w1_scale": rank_weights.w1_scale, - "w2_scale": rank_weights.w2_scale, - "a1_scale": rank_tensors.hidden_states_scale, - "global_num_experts": config.E, - "apply_router_weight_on_input": config.topk == 1, + "w1": + rank_weights.w1, + "w2": + 
rank_weights.w2, + "topk_weights": + rank_tensors.topk_weights, + "topk_ids": + rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), + "expert_map": + rank_tensors.expert_map, + "w1_scale": + rank_weights.w1_scale, + "w2_scale": + rank_weights.w2_scale, + "a1_scale": + rank_tensors.hidden_states_scale, + "global_num_experts": + config.E, + "apply_router_weight_on_input": + config.topk == 1 and config.supports_apply_weight_on_input(), } - out = mk.forward(**mk_kwargs) + + num_tokens = rank_tensors.hidden_states.shape[0] + num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size, + device="cuda", + dtype=torch.int) + + with set_forward_context( + None, + vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ): + out = mk.forward(**mk_kwargs) return out diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 73214066f7ea..aecffae36ae5 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -1,58 +1,316 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union import torch # Fused experts and PrepareFinalize imports +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 BatchedTritonOrDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, + TritonExperts) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) -from vllm.utils import has_deep_ep, has_pplx +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + cutlass_fp4_supported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_fp8_supported) +from vllm.platforms import current_platform +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if has_deep_ep(): + +@dataclass +class PrepareFinalizeInfo: + activation_format: mk.FusedMoEActivationFormat + supported_dtypes: list[Union[torch.dtype, str]] + blocked_quantization_support: bool + backend: Optional[str] + supports_apply_weight_on_input: bool = True + + +@dataclass +class ExpertInfo: + activation_format: mk.FusedMoEActivationFormat + supported_dtypes: list[Union[torch.dtype, str]] + blocked_quantization_support: bool + supports_chunking: bool + supports_expert_map: bool + needs_matching_quant: bool = False + needs_deep_gemm: bool = False + + +PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, + 
PrepareFinalizeInfo] = {} +EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} +MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = [] + +standard_format = mk.FusedMoEActivationFormat.Standard +batched_format = mk.FusedMoEActivationFormat.BatchedExperts +common_float_types: list[Union[torch.dtype, str]] = [ + torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 +] +common_float_and_int_types = common_float_types + [torch.int8] +nv_fp4_types = ["nvfp4"] +fp8_types = [torch.float8_e4m3fn] + + +def register_prepare_and_finalize( + kind, + activation_format: mk.FusedMoEActivationFormat, + supported_dtypes: list[Union[torch.dtype, str]], + blocked_quantization_support: bool, + backend: Optional[str], + force_multigpu: bool = False, + supports_apply_weight_on_input: bool = True, +): + global PREPARE_FINALIZE_INFO + global MK_ALL_PREPARE_FINALIZE_TYPES + global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES + assert kind not in PREPARE_FINALIZE_INFO + + PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo( + activation_format, + supported_dtypes, + blocked_quantization_support, + backend, + supports_apply_weight_on_input, + ) + MK_ALL_PREPARE_FINALIZE_TYPES.append(kind) + if backend is not None or force_multigpu: + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind) + else: + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind) + + +def register_experts( + kind, + activation_format: mk.FusedMoEActivationFormat, + supported_dtypes: list[Union[torch.dtype, str]], + blocked_quantization_support: bool, + supports_chunking: bool, + supports_expert_map: bool, + needs_matching_quant: bool = False, + needs_deep_gemm: bool = False, +): + global EXPERT_INFO + global MK_FUSED_EXPERT_TYPES + assert kind not in EXPERT_INFO + + EXPERT_INFO[kind] = ExpertInfo( + activation_format, + supported_dtypes, + blocked_quantization_support, + supports_chunking, + supports_expert_map, + needs_matching_quant, + needs_deep_gemm, + ) + + MK_FUSED_EXPERT_TYPES.append(kind) + + +def prepare_finalize_info(kind) -> PrepareFinalizeInfo: + info = PREPARE_FINALIZE_INFO.get(kind) + assert info is not None + return info + + +def expert_info(kind) -> ExpertInfo: + info = EXPERT_INFO.get(kind) + assert info is not None + return info + + +register_prepare_and_finalize( + MoEPrepareAndFinalizeNoEP, + standard_format, + common_float_types, + blocked_quantization_support=True, + backend=None, +) + +register_experts( + BatchedTritonExperts, + batched_format, + common_float_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=True, +) + +register_experts( + TritonExperts, + standard_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=True, +) + +register_experts( + NaiveBatchedExperts, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=True, +) + +# Disable on blackwell for now +if has_deep_ep() and not current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from 
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) + register_prepare_and_finalize( + DeepEPHTPrepareAndFinalize, + standard_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_high_throughput", + ) + + register_prepare_and_finalize( + DeepEPLLPrepareAndFinalize, + batched_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_low_latency", + ) + if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) + register_prepare_and_finalize( + PplxPrepareAndFinalize, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + backend="pplx", + ) -MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] -if has_pplx(): - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] -if has_deep_ep(): - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] +if (has_flashinfer_cutlass_fused_moe() + and current_platform.has_device_capability(100)): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) -MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] + register_prepare_and_finalize( + FlashInferCutlassMoEPrepareAndFinalize, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + backend=None, + force_multigpu=True, + supports_apply_weight_on_input=False, + ) -MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + register_experts( + FlashInferExperts, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + supports_chunking=True, + # Note: this is a hack to get it to run for now + supports_expert_map=True, + ) +else: + FlashInferCutlassMoEPrepareAndFinalize = None -MK_FUSED_EXPERT_TYPES = [ - BatchedDeepGemmExperts, - BatchedTritonExperts, - NaiveBatchedExperts, - BatchedTritonOrDeepGemmExperts, - CutlassExpertsFp8, - DeepGemmExperts, - TritonOrDeepGemmExperts, - TritonExperts, -] +if has_deep_gemm() and is_deep_gemm_supported(): + register_experts( + BatchedDeepGemmExperts, + batched_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=False, + needs_deep_gemm=True, + ) + register_experts( + DeepGemmExperts, + standard_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=False, + needs_deep_gemm=True, + ), + register_experts( + BatchedTritonOrDeepGemmExperts, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=True, + needs_deep_gemm=True, + ) + register_experts( + TritonOrDeepGemmExperts, + standard_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=True, + needs_deep_gemm=True, + ) + +if cutlass_fp8_supported(): + from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8, + CutlassExpertsFp8) + register_experts( + CutlassExpertsFp8, + standard_format, + fp8_types, + blocked_quantization_support=False, + supports_chunking=True, + supports_expert_map=False, + ) + 
register_experts( + CutlassBatchedExpertsFp8, + batched_format, + fp8_types, + blocked_quantization_support=False, + supports_chunking=False, + supports_expert_map=False, + ) + +if cutlass_fp4_supported(): + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4) + register_experts( + CutlassExpertsFp4, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=False, + ) MK_QUANT_CONFIGS = [ None, @@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [ # block-quantized weights and per-token activations # block-quantized weights and per-tensor activations ] + +if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): + MK_QUANT_CONFIGS += [ + FusedMoEQuantConfig(quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), + ] + + +def _make_gscale(num_experts: int) -> torch.Tensor: + return torch.ones((num_experts, ), + device=torch.cuda.current_device(), + dtype=torch.float32) + + +def make_prepare_finalize( + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + backend: Optional[str], + moe: FusedMoEConfig, +) -> mk.FusedMoEPrepareAndFinalize: + if backend != "naive" and backend is not None: + prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + assert prepare_finalize is not None + return prepare_finalize + elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp=moe.moe_parallel_config.dp_size > 1, + a1_gscale=_make_gscale(moe.num_local_experts), + ) + else: + return MoEPrepareAndFinalizeNoEP() + + +def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: + s = rank * num_local_experts + e = s + num_local_experts + return t[s:e] + + +def make_fused_experts( + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + moe: FusedMoEConfig, + num_dispatchers: int, + w1_gs: Optional[torch.Tensor], + w2_gs: Optional[torch.Tensor], +) -> mk.FusedMoEPermuteExpertsUnpermute: + + use_fp8 = moe.quant_dtype == torch.float8_e4m3fn + batch_kwargs = { + "max_num_tokens": moe.max_num_tokens, + "num_dispatchers": num_dispatchers, + } + quant_kwargs = { + "use_fp8_w8a8": use_fp8, + "use_int8_w8a8": False, + "use_int8_w8a16": False, + "use_int4_w4a16": False, + "block_shape": moe.block_shape, + "per_act_token_quant": moe.per_act_token_quant, + } + deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + + if fused_experts_type == BatchedDeepGemmExperts: + kwargs = batch_kwargs | { + "block_shape": moe.block_shape, + "per_act_token_quant": moe.per_act_token_quant, + } + print(f"Making BatchedDeepGemmExperts {kwargs} ...") + experts = BatchedDeepGemmExperts(**kwargs) + elif fused_experts_type == BatchedTritonExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making BatchedTritonExperts {kwargs} ...") + experts = BatchedTritonExperts(**kwargs) + elif fused_experts_type == BatchedTritonOrDeepGemmExperts: + kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs + print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") + experts = BatchedTritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == DeepGemmExperts: + print("Making DeepGemmExperts () ...") + experts = DeepGemmExperts() + elif fused_experts_type == TritonExperts: + kwargs = quant_kwargs + print(f"Making TritonExperts {kwargs} ...") + experts = TritonExperts(**kwargs) + elif fused_experts_type == TritonOrDeepGemmExperts: + kwargs = quant_kwargs | deepgemm_kwargs + print(f"Making 
TritonOrDeepGemmExperts {kwargs} ...") + experts = TritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) + elif fused_experts_type == CutlassExpertsFp8: + kwargs = { + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + } + print(f"Making CutlassExpertsFp8 {kwargs} ...") + experts = CutlassExpertsFp8(**kwargs) + elif fused_experts_type == CutlassBatchedExpertsFp8: + kwargs = { + "max_experts_per_worker": moe.num_local_experts, + "num_dispatchers": num_dispatchers, + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + } + print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") + experts = CutlassBatchedExpertsFp8(**kwargs) + elif fused_experts_type == CutlassExpertsFp4: + assert w1_gs is not None and w2_gs is not None + num_experts = moe.num_local_experts + rank = moe.moe_parallel_config.dp_rank + kwargs = { + "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), + "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), + "a1_gscale": _make_gscale(num_experts), + "a2_gscale": _make_gscale(num_experts), + "max_experts_per_worker": num_experts, + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + "num_dispatchers": num_dispatchers, + } + print(f"Making CutlassExpertsFp4 {kwargs} ...") + experts = CutlassExpertsFp4(**kwargs) + elif fused_experts_type == FlashInferExperts: + assert w1_gs is not None and w2_gs is not None + num_experts = moe.num_local_experts + rank = moe.moe_parallel_config.dp_rank + kwargs = { + "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), + "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), + "a1_gscale": _make_gscale(num_experts), + "a2_gscale": _make_gscale(num_experts), + "out_dtype": moe.in_dtype, + "quant_dtype": "nvfp4", + "ep_rank": moe.ep_rank, + "ep_size": moe.ep_size, + "tp_rank": moe.tp_rank, + "tp_size": moe.tp_size, + } + print(f"Making FlashInferExperts {kwargs} ...") + experts = FlashInferExperts(**kwargs) + else: + raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + + return experts diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index dd16ffb2eabe..0da6ee354352 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -52,7 +52,7 @@ def profile_modular_kernel( rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) # make modular kernel - mk = make_modular_kernel(config, vllm_config) + mk = make_modular_kernel(config, vllm_config, weights) mk_kwargs = { "hidden_states": rank_tensors.hidden_states, @@ -83,7 +83,7 @@ def rank_worker( # sanity check from vllm import envs if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py deleted file mode 100644 index 
866f52882bee..000000000000 --- a/tests/kernels/moe/modular_kernel_tools/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm._custom_ops as ops -from vllm.utils.deep_gemm import per_block_cast_to_fp8 - - -def per_token_cast_to_fp8( - x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (block_size - (n % block_size)) % block_size - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, block_size) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - -def make_non_quant_weights( - e: int, - n: int, - k: int, - dtype: torch.dtype, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Return weights w1, w2 - """ - device = torch.cuda.current_device() - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15 - return w1, w2 - - -def make_block_quant_fp8_weights( - e: int, - n: int, - k: int, - block_size: list[int], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Return weights w1, w2, w1_scale, w2_scale - """ - dtype = torch.bfloat16 - device = torch.cuda.current_device() - - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype) - w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * n) + block_n - 1) // block_n - k_tiles_w1 = (k + block_k - 1) // block_k - n_tiles_w2 = (k + block_n - 1) // block_n - k_tiles_w2 = (n + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device) - - w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), - device=device, - dtype=torch.float32) - w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2), - device=device, - dtype=torch.float32) - - assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n, - (k + (block_k - 1)) // block_k) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], - block_size=[block_k, block_n]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], - block_size=[block_k, block_n]) - - return w1, w2, w1_s, w2_s - - -def make_quant_fp8_weights( - e: int, - n: int, - k: int, - per_out_channel_quant: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Return w1, w2, w1_scale, w2_scale - """ - q_dtype = torch.float8_e4m3fn - - w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16) - - # w1 -> w1_q, w2 -> w2_q - w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) - w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) - - n_b_scales = 2 * n if per_out_channel_quant else 1 - k_b_scales = k if per_out_channel_quant else 1 - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - for expert in range(e): - w1_q[expert], 
w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_channel_quant) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_channel_quant) - return w1_q, w2_q, w1_scale, w2_scale diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index edf3e6189243..00b2d780e66f 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, per_act_token_quant=per_act_token_quant, ) - B, B_q, B_scale, _, _, _ = make_test_weights( + (B, B_q, B_scale, _), _ = make_test_weights( num_experts, N // 2, K, @@ -243,7 +243,7 @@ def test_fused_moe_batched_experts( act_dtype = dtype quant_dtype = None - w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights( + (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights( e, n, k, diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 75b2e9f79178..9e4eaf221f24 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, use_int8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False, + use_mxfp4_w4a4=False, per_act_token_quant=False, block_shape=block_size) @@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 8e680c722935..5e4a93963f8e 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.int8, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.int8, + per_act_token_quant=False, + block_shape=block_size) # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 1aee1ed8c376..3b1618dacac7 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -9,6 +9,7 @@ import random import pytest import torch +from tests.kernels.moe.utils import per_token_cast_to_fp8 from tests.kernels.utils import baseline_scaled_mm from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -16,20 +17,6 @@ from vllm.utils import cdiv from vllm.utils.deep_gemm import per_block_cast_to_fp8 -def per_token_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * - (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ (4, 8192, 7168, 4096), (4, 8192, 2048, 7168), @@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm( device=device, dtype=torch.float)) for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128]) for i in range(num_groups): a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 9b064db973dd..6f95581a5e60 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -70,8 +70,10 @@ def make_block_quant_fp8_weights( """ Return weights w1q, w2q, w1_scale, w2_scale """ - w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( - e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) + (_, w1q, w1_scale, _), (_, w2q, w2_scale, + _) = make_test_weights(e, n, k, torch.bfloat16, + torch.float8_e4m3fn, + block_size) return w1q, w2q, w1_scale, w2_scale diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index b2b78662c9de..4472f34a6291 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size): # Note: W1 has shape (E, 2N, K), so N = 512 # can trigger the deepgemm path. 
MNKs = [ - (1024, 512, 128), - (1024, 512, 512), - (2048, 512, 512), + (1024, 768, 128), + (1024, 768, 512), + (2048, 768, 512), (512, 1024, 1024), (512, 2048, 2048), (4096, 4096, 1024), diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py new file mode 100644 index 000000000000..1c14df2b914a --- /dev/null +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.moe.utils import make_test_weights +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +#@pytest.mark.parametrize("e", [128, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + + quant_blocksize = 16 + + (_, w1_q, w1_blockscale, + w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? 
+ per_act_token_quant=False, + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, + score, + topk, + renormalize=False) + + a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + + assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + + flashinfer_experts = FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + FlashInferExperts( + a1_gscale=a1_gs, + g1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + g2_alphas=(1 / w2_gs), + out_dtype=dtype, + quant_dtype="nvfp4", + )) + + flashinfer_output = flashinfer_experts( + hidden_states=a, + w1=w1_q, + w1_scale=w1_blockscale, + w2=w2_q, + w2_scale=w2_blockscale, + a1_scale=a1_gs, + a2_scale=a2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + + # Reference check: + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) + + torch.testing.assert_close(torch_output, + flashinfer_output, + atol=1e-1, + rtol=1e-1) + + +if __name__ == "__main__": + test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6f2869c3a61d..d45982384eb3 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import textwrap +import traceback from itertools import product from typing import Optional @@ -10,41 +12,51 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, current_platform, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.layer import TritonExperts -from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, reference_moe_impl, run_modular_kernel) from .modular_kernel_tools.mk_objects import 
( MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, parallel_launch_with_config) -# TODO (varun): These requirements are very strict and could be relaxed. -has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) +has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx() + or has_flashinfer_cutlass_fused_moe()) -meets_package_requirements = pytest.mark.skipif( - not has_all_packages, - reason="Requires deep_ep & deep_gemm & pplx packages", +meets_multi_gpu_requirements = pytest.mark.skipif( + not has_any_multi_gpu_package, + reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages", ) +def format_result(verbose, msg, ex=None): + if ex is not None: + x = str(ex) + newx = x.strip(" \n\t")[:16] + if len(newx) < len(x): + newx = newx + " ..." + + prefix = "E\t" + print(f"{textwrap.indent(traceback.format_exc(), prefix)}") + print(f"FAILED {msg} - {newx}\n") + elif verbose: + print(f"PASSED {msg}") + else: + print(".", end="") + + def rank_worker( pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group, config: Config, weights: WeightTensors, + verbose: bool, ): current_platform.seed_everything(pgi.rank) @@ -61,39 +73,64 @@ def rank_worker( TOPKs = config.topks assert isinstance(TOPKs, list) + exceptions = [] + count = 0 + for m, topk in product(Ms, TOPKs): - print(f"Running m={m}, topk={topk} ...") - # override m and topk - cfgx = copy.deepcopy(config) - cfgx.Ms = m - cfgx.topks = topk + try: + print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") + count = count + 1 + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk - # inputs for rank - rank_tensors = RankTensors.make(cfgx, pgi) + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) - # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) - with set_current_vllm_config(vllm_config): - ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) - torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + if config.quant_dtype == "nvfp4": + atol = 1e-1 + rtol = 1e-1 + else: + atol = 3e-2 + rtol = 3e-2 + + torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol) + format_result(verbose, config.describe()) + except Exception as ex: + format_result(verbose, config.describe(), ex) + exceptions.append(ex) + + if len(exceptions) > 0: + raise RuntimeError( + f"{len(exceptions)} of {count} tests failed in child process, " + f"rank={pgi.rank}.") + else: + print(f"{count} of {count} tests passed in child process, " + f"rank={pgi.rank}.") -def run(config: Config): +def run(config: Config, verbose: bool): assert config.is_valid() - print(f"Testing config \n{config.describe()} ...") weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + env_dict, config, weights, verbose) Ms = [32, 64] -Ks = [7168] # hidden sizes +# hidden sizes, making this too large will cause fp4 tests to fail. +# Also needs to be a multiple of 1024 for deep_gemm. 
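The hard-coded lists of fused-expert classes in is_nyi_config are replaced below by capability flags queried through expert_info (defined in modular_kernel_tools/mk_objects.py, which is not part of this excerpt). A sketch of the intended query pattern, under the assumption that the returned record exposes only the two flags actually referenced by this test:

from dataclasses import dataclass

@dataclass
class ExpertInfo:
    # Assumed attributes; only these two are used by is_nyi_config().
    needs_matching_quant: bool
    supports_expert_map: bool

def is_nyi(info: ExpertInfo, per_act_token: bool, per_out_ch: bool) -> bool:
    if info.needs_matching_quant:
        # Triton-based kernels expect per-act-token and per-out-channel
        # quantization to be enabled together or not at all.
        return (per_act_token + per_out_ch) == 1
    # Otherwise the remaining known limitation is missing expert_map support.
    return not info.supports_expert_map
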
+Ks = [2048] Ns = [2048] TOPKs = [4, 1] Es = [32] @@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16] def is_nyi_config(config: Config) -> bool: # We know these configs to be legitimate. but still fail. + info = expert_info(config.fused_experts_type) - if (config.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - TritonExperts, TritonOrDeepGemmExperts - ]): + if info.needs_matching_quant: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. unsupported_quant_config = ((config.is_per_act_token_quant + config.is_per_out_ch_quant) == 1) return unsupported_quant_config - # cutlass kernels dont support expert_maps yet. - return config.fused_experts_type == CutlassExpertsFp8 + return not info.supports_expert_map @pytest.mark.parametrize("k", Ks) @@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool: product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [2]) -@meets_package_requirements +@meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, + quant_config: Optional[FusedMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): config = Config( Ms=Ms, @@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu( fused_moe_chunk_size=fused_moe_chunk_size, world_size=world_size, ) + if not config.is_valid(): pytest.skip(f"Tests config {config} is not valid. Skipping ...") if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. Skipping ...") - print(f"{config.describe()}") - run(config) + verbosity = pytestconfig.getoption('verbose') + run(config, verbosity > 0) @pytest.mark.parametrize("k", Ks) @@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu( product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [1]) -@meets_package_requirements def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, + quant_config: Optional[FusedMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): config = Config( Ms=Ms, K=k, @@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu( if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - run(config) + verbosity = pytestconfig.getoption('verbose') + run(config, verbosity > 0) if __name__ == '__main__': @@ -211,4 +246,4 @@ if __name__ == '__main__': args = parser.parse_args() config = make_config(args) - run(config) + run(config, True) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 3ff385360299..30388ef9375d 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -3,6 +3,7 @@ import pytest import torch +from tests.kernels.moe.utils import make_test_weights from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype) @@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 - round_up = lambda x, y: (x + y - 1) // y * y - sf_w1_2n = round_up(2 * n, 128) - sf_w1_k = round_up(k // quant_blocksize, 4) - w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - sf_w2_k = round_up(k, 128) - sf_w2_n = round_up(n // quant_blocksize, 4) - w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), - device="cuda", - dtype=torch.float8_e4m3fn) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1_q = torch.empty((e, 2 * n, k // 2), - device="cuda", - dtype=torch.uint8) - w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - - for expert in range(e): - w1_amax = torch.abs(w1).max().to(torch.float32) - w2_amax = torch.abs(w2).max().to(torch.float32) - w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax - w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax - - w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( - w1[expert], w1_gs[expert]) - - w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( - w2[expert], w2_gs[expert]) + (_, w1_q, w1_blockscale, + w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? 
+ per_act_token_quant=False, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, @@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + cutlass_output = cutlass_moe_fp4( a=a, a1_gscale=a1_gs, @@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, n=n, k=k, e=e, - device=a.device, ) # Reference check: a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) - _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, a_scale_interleaved, a_global_scale, @@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_blockscale[idx], w1_gs[idx], - dtype=w1.dtype, - device=w1.device, + dtype=dtype, + device=w1_q.device, block_size=quant_blocksize) w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_blockscale[idx], w2_gs[idx], - dtype=w2.dtype, - device=w2.device, + dtype=dtype, + device=w2_q.device, block_size=quant_blocksize) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index e4f4a393dfd5..f98937ee6c52 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,7 +9,8 @@ import torch from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassBatchedExpertsFp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -123,12 +124,8 @@ def pplx_cutlass_moe( num_local_experts=num_local_experts, num_dispatchers=num_dispatchers) - experts = CutlassExpertsFp8(num_local_experts, - out_dtype, - per_act_token, - per_out_ch, - num_dispatchers=num_dispatchers, - use_batched_format=True) + experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, + out_dtype, per_act_token, per_out_ch) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index fbef6706beaf..c2064de97358 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -770,7 +770,7 @@ def test_pplx_moe_slow( a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - _, w1, w1_s, _, w2, w2_s = make_test_weights( + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( e, n, k, @@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, args = dict() if make_weights: - _, w1, w1_s, _, w2, w2_s = make_test_weights( + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( e, n, k, diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index c33134981acc..82960bd57345 100644 --- 
a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import per_block_cast_to_int8 +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX) from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) @@ -169,28 +171,41 @@ def make_quantized_test_activations( def moe_quantize_weights( w: torch.Tensor, w_s: Optional[torch.Tensor], - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_token_quant: bool, block_shape: Optional[list[int]], -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8 + or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported" + + w_gs = None if block_shape is not None: assert not per_token_quant if quant_dtype == torch.int8: w, w_s = per_block_cast_to_int8(w, block_shape) - else: + elif quant_dtype == torch.float8_e4m3fn: w, w_s = per_block_cast_to_fp8(w, block_shape) + elif quant_dtype == "nvfp4": + raise RuntimeError("blocked quantization not supported for nvfp4") + else: + raise RuntimeError(f"Unsupported quant type {quant_dtype}") else: if quant_dtype == torch.int8: w, w_s = ops.scaled_int8_quant( w, w_s, use_per_token_if_dynamic=per_token_quant) - else: + elif quant_dtype == torch.float8_e4m3fn: w, w_s = ops.scaled_fp8_quant( w, w_s, use_per_token_if_dynamic=per_token_quant) + elif quant_dtype == "nvfp4": + assert not per_token_quant + w_amax = torch.abs(w).max().to(torch.float32) + w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax + w, w_s = ops.scaled_fp4_quant(w, w_gs) + else: + raise RuntimeError(f"Unsupported quant type {quant_dtype}") - return w, w_s + return w, w_s, w_gs def make_test_weight( @@ -198,21 +213,26 @@ def make_test_weight( rows: int, cols: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 + w_gs = None if quant_dtype is not None: w_l = [None] * e w_s_l = [None] * e + w_gs_l = [None] * e for idx in range(e): - w_l[idx], w_s_l[idx] = moe_quantize_weights( + w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w = torch.stack(w_l) w_s = torch.stack(w_s_l) + if e > 0 and w_gs_l[0] is not None: + w_gs = torch.stack(w_gs_l) if w_s.ndim == 2: assert w_s.shape[-1] == 1 w_s = w_s.view(-1, 1, 1) @@ -225,8 +245,9 @@ def make_test_weight( else: w = w_16 w_s = None + w_gs = None - return w_16, w, w_s + return w_16, w, w_s, w_gs def make_test_weights( @@ -234,14 +255,30 @@ def make_test_weights( n: int, k: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: 
Optional[torch.dtype] = None, + quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, - torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]], + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]]: return ( - *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), - *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, + per_act_token_quant), + make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, + per_act_token_quant), ) + + +def per_token_cast_to_fp8( + x: torch.Tensor, + block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (block_size - (n % block_size)) % block_size + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, block_size) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 127a340fc6c6..9e5aa4e4c2a8 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -105,7 +105,8 @@ class DeviceCommunicatorBase: # we initialize the all2all manager used in expert parallel. use_ep = config.parallel_config.data_parallel_size > 1 - self.use_all2all = "ep" in unique_name and use_ep + self.is_ep_communicator = "ep" in unique_name + self.use_all2all = self.is_ep_communicator and use_ep self.all2all_manager: Optional[All2AllManagerBase] = None def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: @@ -246,7 +247,7 @@ class DeviceCommunicatorBase: """ Prepare the communication buffer for the model. 
""" - if not self.use_all2all: + if not self.is_ep_communicator: return moe_modules = [ @@ -254,7 +255,7 @@ class DeviceCommunicatorBase: if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module.moe_config) + module.quant_method.init_prepare_finalize() def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3d40879b4ccb..3007643d7a28 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -49,7 +49,8 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4, + cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( @@ -69,6 +70,7 @@ if HAS_TRITON: "cutlass_moe_fp8", "cutlass_moe_fp4", "CutlassExpertsFp8", + "CutlassBatchedExpertsFp8", "TritonExperts", "BatchedTritonExperts", "DeepGemmExperts", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index c48a0137c306..d9cfe96f7a03 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -254,18 +254,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py 
b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index fc30e84e6656..89d7412ee223 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): a, aq, M, N, K, topk, global_num_experts, local_num_experts, expert_tokens_metadata) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None @@ -151,4 +161,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace2, expert_tokens_meta, - apply_router_weight_on_input, extra_expert_args) + apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 31ea826f1f97..7c1a7b636a9c 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -45,7 +45,6 @@ def get_quant_config_weight_quant( return _get_quant_config_quantization_args(quant_config, "weights") -# TODO (bnell): use scalar_type instead of bools? def get_config_quant_dtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, @@ -65,7 +64,8 @@ def get_config_quant_dtype( @dataclass class FusedMoEQuantConfig: # The post quantization activation type. - quant_dtype: Optional[torch.dtype] = None + # TODO (bnell): use scalar_type instead of Union. + quant_dtype: Union[torch.dtype, str, None] = None per_act_token_quant: bool = False per_out_ch_quant: bool = False block_shape: Optional[list[int]] = None @@ -141,6 +141,7 @@ class FusedMoEQuantConfig: use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + use_mxfp4_w4a4, ] ]) <= 1, "Quantization flags are mutually exclusive." 
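The widening of quant_dtype below, from Optional[torch.dtype] to Union[torch.dtype, str, None], exists so the nvfp4 format, which has no native torch dtype, can be named directly instead of being smuggled through torch.uint8. A minimal sketch of how downstream code is expected to branch on the new field; the helper name is illustrative, not part of the patch:

from typing import Union

import torch

def quant_kind(quant_dtype: Union[torch.dtype, str, None]) -> str:
    # None means activations are left unquantized.
    if quant_dtype is None:
        return "unquantized"
    # "nvfp4" is the only string value introduced by this patch; it is set
    # when the layer is configured with ModelOptNvFp4Config.
    if quant_dtype == "nvfp4":
        return "nvfp4 (packed fp4 values with block and global scales)"
    if quant_dtype in (torch.float8_e4m3fn, torch.int8):
        return f"native torch dtype: {quant_dtype}"
    raise ValueError(f"Unsupported quant type {quant_dtype}")
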
@@ -334,7 +335,7 @@ class FusedMoEConfig: assert self.max_num_tokens > 0 @property - def quant_dtype(self) -> Optional[torch.dtype]: + def quant_dtype(self) -> Union[torch.dtype, str, None]: if self.quant_config is not None: return self.quant_config.quant_dtype else: @@ -429,7 +430,7 @@ class FusedMoEConfig: block_shape = None per_act_token_quant = False per_out_ch_quant = False - quant_dtype: Optional[torch.dtype] = None + quant_dtype: Union[torch.dtype, str, None] = None input_quant = get_quant_config_input_quant(quant_config) weight_quant = get_quant_config_weight_quant(quant_config) @@ -453,7 +454,7 @@ class FusedMoEConfig: ModelOptNvFp4Config) if quant_dtype is None and isinstance(quant_config, ModelOptNvFp4Config): - quant_dtype = torch.uint8 + quant_dtype = "nvfp4" if weight_quant is not None: per_out_ch_quant = ( diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 2585a2953c9d..0a02b558d09e 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch @@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, _fp8_quantize, - _resize_cache, - extract_required_args) + _resize_cache) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -213,19 +212,14 @@ def run_cutlass_moe_fp8( output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) -# TODO (bnell): split class batched vs. non-batched? -# maybe remove need for passing aq to workspace_shapes -class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_experts_per_worker: int, out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, block_shape: Optional[list[int]] = None, - num_dispatchers: Optional[int] = None, - use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( @@ -234,33 +228,84 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, )) - assert max_experts_per_worker > 0 - assert not use_batched_format or num_dispatchers is not None - self.max_experts_per_worker = max_experts_per_worker - self.num_dispatchers = num_dispatchers self.out_dtype = out_dtype - self.use_batched_format = use_batched_format + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. 
+ return TopKWeightAndReduceDelegate() + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + + expert_num_tokens = None + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + + activation_callable = lambda o, i: self.activation(activation, o, i) + + use_batched_format = self.activation_formats[ + 0] == mk.FusedMoEActivationFormat.BatchedExperts + + in_dtype = hidden_states.dtype + run_cutlass_moe_fp8( + output, hidden_states, w1, w2, topk_ids, activation_callable, + global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, self.per_out_ch_quant, + use_batched_format) + + +class CutlassExpertsFp8(CutlassExpertsFp8Base): + + def __init__( + self, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, + ): + super().__init__( + out_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape, + ) @property def activation_formats( self ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.use_batched_format: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) - else: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: - return not self.use_batched_format + return True def supports_expert_map(self) -> bool: - return not self.use_batched_format - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return True def workspace_shapes( self, @@ -274,54 +319,69 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - workspace1: tuple[int, ...] = () - workspace2: tuple[int, ...] = () - output: tuple[int, ...] 
= () - if self.use_batched_format: - padded_M = aq.size(1) - num_dp = self.num_dispatchers - assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, - max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, - (N // 2)) - output = (self.max_experts_per_worker, padded_M, K) - else: - workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, N // 2) - output = (M * topk, K) + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk, N // 2) + output = (M * topk, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): - assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" - assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" - expert_num_tokens = None - if expert_tokens_meta is not None: - expert_num_tokens = expert_tokens_meta.expert_num_tokens +class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - activation_callable = lambda o, i: self.activation(activation, o, i) + def __init__( + self, + max_experts_per_worker: int, + num_dispatchers: int, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, + ): + super().__init__( + out_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape, + ) + assert max_experts_per_worker > 0 + self.max_experts_per_worker = max_experts_per_worker + self.num_dispatchers = num_dispatchers - in_dtype = hidden_states.dtype - run_cutlass_moe_fp8( - output, hidden_states, w1, w2, topk_ids, activation_callable, - global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, expert_num_tokens, - self.out_dtype if self.out_dtype is not None else in_dtype, - self.per_act_token_quant, self.per_out_ch_quant, - self.use_batched_format) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + # TODO(bnell): maybe remove need for passing aq to workspace_shapes + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + padded_M = aq.size(1) + num_dp = self.num_dispatchers + assert num_dp is not None + workspace1 = (self.max_experts_per_worker, padded_M * num_dp, + max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2)) + output = (self.max_experts_per_worker, padded_M, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not 
None else a.dtype) def cutlass_moe_fp8( @@ -387,11 +447,9 @@ def cutlass_moe_fp8( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - max_experts_per_worker=num_experts, out_dtype=a.dtype, per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, - use_batched_format=False, ), ) @@ -476,8 +534,9 @@ def run_cutlass_moe_fp4( e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape - assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", - " between weights.") + assert (e_w1 == e_w2 + and e_w1 == e), ("Number of experts must match", + f" between weights. {e_w1}, {e_w2}, {e}") assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " @@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, max_experts_per_worker: int, out_dtype: torch.dtype, per_act_token_quant: bool, @@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): use_batched_format: bool = False, ): super().__init__( + # NVFP4 requires two levels of quantization, which involves + # computing some scaling factors dynamically. This makes it + # incompatible with the typical prepare -> MoE -> finalize + # pipeline. Move the quantization logic into the MoE body. FusedMoEQuantConfig( - quant_dtype=torch.uint8, + quant_dtype=None, # skip quantization in prepare/finalize per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, @@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): self.out_dtype = out_dtype self.use_batched_format = use_batched_format + # TODO(bnell): put this stuff into quant config? + self.g1_alphas = g1_alphas + self.g2_alphas = g2_alphas + self.a1_gscale = a1_gscale + self.a2_gscale = a2_gscale + @property def activation_formats( self @@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. 
- return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, - w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): - required_keys = [ - "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", - "e", "device" - ] - (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, - device) = extract_required_args(extra_expert_args, required_keys) + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids) + n = w2.shape[2] * 2 + run_cutlass_moe_fp4( output=output, a=hidden_states, - a1_gscale=a1_gscale, + a1_gscale=self.a1_gscale, w1_fp4=w1, w1_blockscale=w1_scale, - w1_alphas=g1_alphas, - a2_gscale=a2_gscale, + w1_alphas=self.g1_alphas, + a2_gscale=self.a2_gscale, w2_fp4=w2, w2_blockscale=w2_scale, - w2_alphas=g2_alphas, + w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, workspace13=workspace13, @@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): n=n, k=k, e=e, - device=device, + device=hidden_states.device, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -677,7 +757,6 @@ def cutlass_moe_fp4( n: int, k: int, e: int, - device: torch.device, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False) -> torch.Tensor: assert expert_map is None, ("Expert Parallelism / expert_map " @@ -686,6 +765,10 @@ def cutlass_moe_fp4( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( + g1_alphas, + g2_alphas, + a1_gscale, + a2_gscale, max_experts_per_worker=e, out_dtype=a.dtype, per_act_token_quant=False, @@ -693,29 +776,7 @@ def cutlass_moe_fp4( use_batched_format=False, ), ) - extra_expert_args = { - 'g1_alphas': g1_alphas, - 'g2_alphas': g2_alphas, - 'a1_gscale': a1_gscale, - 'a2_gscale': a2_gscale, - 'm': m, - 'n': n, - 'k': k, - 'e': e, - 'device': device, - } - # NVFP4 requires two levels of quantization, which involves computing some - # scaling factors dynamically. This makes it incompatible with the typical - # prepare -> MoE -> finalize pipeline. Move the quantization logic into the - # MoE body. 
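Since the global scales and alphas are now constructor state of CutlassExpertsFp4 rather than entries in extra_expert_args, callers build the modular kernel once and invoke it without any extra_* dictionaries. A condensed sketch of the new wiring, with placeholder scale tensors standing in for values that would normally come from the quantized checkpoint:

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

e = 8  # number of experts (placeholder)
ones = lambda: torch.ones((e, ), device="cuda", dtype=torch.float32)

fp4_experts = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    CutlassExpertsFp4(
        ones(),  # g1_alphas
        ones(),  # g2_alphas
        ones(),  # a1_gscale
        ones(),  # a2_gscale
        max_experts_per_worker=e,
        out_dtype=torch.half,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        use_batched_format=False,
    ),
)
# The kernel is then called exactly as before, minus the removed
# extra_expert_args / extra_prepare_args / extra_finalize_args keywords.
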
- extra_prepare_args = { - 'skip_quant': True, - } - # Similar reason as above. - extra_finalize_args = { - 'skip_weight_reduce': True, - } return fn( hidden_states=a, w1=w1_fp4, @@ -731,9 +792,6 @@ def cutlass_moe_fp4( a1_scale=None, a2_scale=None, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args, - extra_prepare_args=extra_prepare_args, - extra_finalize_args=extra_finalize_args, ) @@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts( k = w1_q.size(1) n = w2_q.size(1) - expert_offsets = torch.empty((num_experts + 1, ), - dtype=torch.int32, - device="cuda") - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device="cuda") - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device="cuda") - topk = topk_ids.size(1) a_q, a1_scale = _fp8_quantize(a, @@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts( block_shape=[128, 128]) device = a_q.device + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 9b8175f42a9d..7b8467a5a0cf 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Any, Optional +from typing import Optional import torch from tqdm import tqdm @@ -230,7 +230,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): assert self.block_shape is not None assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index f6b62254e7b4..437e569d3130 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import deep_ep import torch @@ -127,12 +127,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_topk_weights) def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], 
Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -187,11 +191,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert self.handle is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index cfc2bdcf0240..93ac11fb4bfb 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Optional, Union import deep_ep import torch @@ -77,7 +77,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -111,12 +111,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return x, x_scales def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -162,11 +166,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return (expert_x, expert_x_scale, expert_tokens_meta, None, None) - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 4e3e15a35ada..3fbe2a0bc69b 100644 --- 
a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional, Union import torch @@ -8,8 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import extract_required_args + TopKWeightAndReduceNoOP) from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe) @@ -20,7 +19,7 @@ def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> bool: """ - Check if the given problem size is supported by the FlashInfer CUTLASS MoE + Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel. """ if not has_flashinfer_cutlass_fused_moe(): @@ -43,31 +42,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_nvfp4_w4a4: bool = False, - use_fp8_w8a8: bool = False, - use_dp: bool = False, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + out_dtype: torch.dtype, + quant_dtype: Union[torch.dtype, str, None], ep_rank: int = 0, ep_size: int = 1, tp_rank: int = 0, tp_size: int = 1, - num_dispatchers: Optional[int] = None, - use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( - quant_dtype=torch.uint8, + quant_dtype=quant_dtype, per_act_token_quant=False, block_shape=None, )) - self.use_nvfp4_w4a4 = use_nvfp4_w4a4 - self.use_fp8_w8a8 = use_fp8_w8a8 + assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is " + "currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank self.tp_size = tp_size - self.use_dp = use_dp - assert not use_batched_format or num_dispatchers is not None - self.num_dispatchers = num_dispatchers + self.g1_alphas = g1_alphas + self.g2_alphas = g2_alphas + self.a1_gscale = a1_gscale + self.a2_gscale = a2_gscale + self.out_dtype = out_dtype @property def activation_formats( @@ -84,8 +86,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -117,8 +118,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. 
""" - assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") aq_m, aq_n = aq.shape workspace2 = () output_shape = (aq_m, aq_n * 2) @@ -149,21 +148,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], - extra_expert_args: Optional[dict[str, Any]], ): - assert extra_expert_args is not None, \ - "extra_expert_args must be provided" - required_keys = [ - 'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype' - ] - - g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = ( - extract_required_args(extra_expert_args, required_keys)) - # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. - assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") # Ensure w1_scale and w2_scale are not None before calling view assert w1_scale is not None and w2_scale is not None, ( @@ -171,12 +158,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): "be None for FlashInferExperts") quant_scales = [ - a1_gscale, + self.a1_gscale, w1_scale.view(torch.int32), - g1_alphas, - a2_gscale, + self.g1_alphas, + self.a2_gscale, w2_scale.view(torch.int32), - g2_alphas, + self.g2_alphas, ] _ = flashinfer_cutlass_fused_moe( input=hidden_states, @@ -185,7 +172,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): # FlashInfer API requires weight to be long for nvfp4 fc1_expert_weights=w1.view(torch.long), fc2_expert_weights=w2.view(torch.long), - output_dtype=out_dtype, + output_dtype=self.out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, tp_size=self.tp_size, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 36aca8cf74b6..061b02172c44 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - extract_required_args, moe_kernel_quantize_input) + moe_kernel_quantize_input) from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -21,16 +21,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, - quant_dtype: Optional[torch.dtype] = None, - per_channel_quant: bool = False, - block_shape: Optional[list[int]] = None, + use_dp: bool, + a1_gscale: Optional[torch.Tensor], num_dispatchers: int = 1, ): super().__init__() - self.per_channel_quant = per_channel_quant - self.block_shape = block_shape - self.quant_dtype = quant_dtype self.num_dispatchers_ = num_dispatchers + self.use_dp = use_dp + self.a1_gscale = a1_gscale + self.local_tokens = None @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -55,10 +54,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + # TODO(bnell): use 
quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -67,22 +67,22 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - (a1_gscale, use_dp, local_tokens) = extract_required_args( - extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) - a1q, a1q_scale = moe_kernel_quantize_input( a1, - a1_gscale, + self.a1_gscale, quant_config.quant_dtype, - self.per_channel_quant, - self.block_shape, - is_fp4_scale_swizzled=not use_dp, # Swizzling after communication + quant_config.per_act_token_quant, + quant_config.block_shape, + # Swizzling after communication + is_fp4_scale_swizzled=not self.use_dp, ) - if use_dp: + if self.use_dp: topk_weights, topk_ids, a1q, a1q_scale = \ - get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501 - dim=0, - sizes=get_local_sizes()) + get_dp_group().all_gatherv( + [topk_weights, topk_ids, a1q, a1q_scale], + dim=0, + sizes=get_local_sizes(), + ) a1_m, a1_n = a1q.shape a1q_scale = nvfp4_block_scale_interleave(a1q_scale) @@ -91,13 +91,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: - (use_dp, - local_tokens) = extract_required_args(extra_finalize_args, - ['use_dp', 'local_tokens']) - if use_dp: + if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, dim=0, sizes=get_local_sizes()) output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9a5c85e120cc..b46f4be4b912 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Any, Optional +from typing import Optional import torch @@ -496,12 +496,16 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.num_dispatchers_ def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], 
Optional[torch.Tensor]]: @@ -590,11 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return b_a1, b_a1_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) weight_and_reduce_impl.apply( @@ -688,18 +696,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): else: return t.to(f32) * group_broadcast(scale, t.shape) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens @@ -894,18 +912,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output, a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + 
w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): # Check constraints. if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1c497fa5521b..e58a9e568d4a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1394,9 +1394,9 @@ def fused_experts(hidden_states: torch.Tensor, # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. - should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used( - ) or _valid_deep_gemm(hidden_states, w1, w2) - if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): + if (allow_deep_gemm and use_fp8_w8a8 + and (is_blackwell_deep_gemm_e8m0_used() + or _valid_deep_gemm(hidden_states, w1, w2))): assert apply_router_weight_on_input is False assert is_act_and_mul, ( "DeepGemm only supports is_act_and_mul=True for now.") @@ -1905,7 +1905,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): # Check constraints. if self.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 6b5284dc6c96..312befe2c1d7 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import torch @@ -8,7 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import extract_required_args from vllm.utils import has_triton_kernels logger = init_logger(__name__) @@ -160,12 +159,16 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): num_dispatchers: int, w1_precision: "PrecisionConfig", w2_precision: "PrecisionConfig", + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], ): super().__init__(quant_config) self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers self.w1_precision = w1_precision self.w2_precision = w2_precision + self.w1_bias = w1_bias + self.w2_bias = w2_bias @property def activation_formats( @@ -219,11 +222,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): - w1_bias, w2_bias = (extract_required_args(extra_expert_args, - ["w1_bias", "w2_bias"])) - return triton_kernel_fused_experts( output, hidden_states, @@ -240,8 +239,8 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): 
expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, + w1_bias=self.w1_bias, + w2_bias=self.w2_bias, w1_precision=self.w1_precision, w2_precision=self.w2_precision, a1_scale=a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36e75825853e..c3c6e4782750 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -37,7 +37,6 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, round_up) -from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -49,9 +48,6 @@ if current_platform.is_cuda_alike(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) - if has_flashinfer(): - from .flashinfer_cutlass_prepare_finalize import ( - FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -80,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - moe: FusedMoEConfig + # TODO(bnell): also pass quant_config? + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.moe = moe + self.fused_experts: Optional[Callable] = None + self.topk_indices_dtype = None @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -99,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): return False @staticmethod - def maybe_make_prepare_finalize( - moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: + def _maybe_make_prepare_finalize( + moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None - if moe.use_flashinfer_cutlass_kernels: - prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( - quant_dtype=moe.quant_dtype, ) + assert not moe.use_flashinfer_cutlass_kernels, \ + "Must be created in modelopt.py" + if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -188,14 +189,25 @@ class FusedMoEMethodBase(QuantizeMethodBase): return prepare_finalize - def init_prepare_finalize(self, moe: FusedMoEConfig): - self.moe = moe - prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize( - self.moe) + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[FusedMoEPrepareAndFinalize]: + if moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + else: + return None + + def init_prepare_finalize(self): + assert self.moe is not None + prepare_finalize = self.maybe_make_prepare_finalize(self.moe) - self.topk_indices_dtype = None if prepare_finalize is not None: - logger.debug("%s", prepare_finalize.__class__.__name__) + logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, + self, id(self)) + assert self.topk_indices_dtype is None + assert self.fused_experts is None, \ + f"Attempt to override experts for {id(self)}!" 
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, self.moe) self.fused_experts = FusedMoEModularKernel( @@ -214,12 +226,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") - def maybe_swap_experts_impl( - self, - moe_parallel_config: FusedMoEParallelConfig, - ): - pass - @abstractmethod def apply( self, @@ -251,10 +257,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: FusedMoEConfig): - super().__init__() - self.fused_experts = fused_experts # type: ignore - self.topk_indices_dtype = None - self.moe = moe + super().__init__(moe) self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: @@ -266,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, + # TODO(bnell): Remove. Every layer should have an moe config object. moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: if (prepare_finalize.activation_format == @@ -474,9 +478,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - else: - # add w1_bias/w2_bias to kwargs if they exist - kwargs = dict( + elif self.fused_experts is not None: + if self.has_bias: + raise ValueError( + "FusedMoEModularKernel does not support bias.") + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -488,17 +494,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): global_num_experts=global_num_experts, expert_map=expert_map, ) - if isinstance(self.fused_experts, - FusedMoEModularKernel) and self.has_bias: - raise ValueError( - "FusedMoEModularKernel does not support bias.") - if self.has_bias: - kwargs.update({ - "w1_bias": getattr(layer, "w13_bias", None), - "w2_bias": getattr(layer, "w2_bias", None), - }) - - return self.fused_experts(**kwargs) + else: + assert fused_experts is not None + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) def forward_cpu( self, @@ -868,8 +879,6 @@ class FusedMoE(CustomOp): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) - if isinstance(self.quant_method, FusedMoEMethodBase): - self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6262904e4dca..2ea6383d5ae9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from math import prod -from typing import Any, Optional, final +from typing import Optional, 
final import torch @@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC): raise NotImplementedError @abstractmethod - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + ) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC): workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): """ This function computes the intermediate result of a Mixture of Experts @@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module): f"{fused_experts.activation_formats[0]}") def _do_fused_experts( - self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, - a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: + self, + fused_out: Optional[torch.Tensor], + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -509,7 +529,7 @@ class 
FusedMoEModularKernel(torch.nn.Module): workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) return fused_out @@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module): CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) + # TODO(bnell): get rid of one level here, update slice functions + # to nops on num_chunks==1 + if not self.fused_experts.supports_chunking() or num_chunks == 1: return self._do_fused_experts( fused_out=None, @@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) # Chunking required case assert num_chunks > 1 @@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) - m = None - if extra_expert_args is not None and 'm' in extra_expert_args: - m = extra_expert_args.get('m') - - if extra_expert_args is not None: - chunked_extra_expert_args = extra_expert_args - else: - chunked_extra_expert_args = {} - for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( slice_input_tensors(chunk_idx)) @@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_tokens_meta, c_topk_ids, local_num_experts, expert_map) - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - - if m is not None: - chunked_extra_expert_args['m'] = e - s self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), a1=a1, @@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=chunked_extra_expert_args) + ) return fused_out @@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module): a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - extra_expert_args: Optional[dict] = None, - extra_prepare_args: Optional[dict] = None, - extra_finalize_args: Optional[dict] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module): - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. - - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to - fused_experts.apply. - - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass - to prepare. - - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass - to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, - extra_prepare_args, ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
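For illustration, here is a minimal sketch of how a modular kernel is assembled and invoked once the extra_prepare_args / extra_expert_args / extra_finalize_args plumbing above is removed: per-kernel state (global scales, DP flag, output dtype) is supplied through the prepare/finalize and experts constructors, so the forward() call carries only tensors. This is a sketch, not code from the patch; the variables layer, x, use_dp, moe, topk_weights, topk_ids, global_num_experts and expert_map are illustrative placeholders.

# Sketch only: constructor-configured objects replace the extra_*_args dicts.
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
    FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)

prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
    use_dp=use_dp,                            # decided by the caller (EP+DP vs. EP+TP)
    a1_gscale=layer.w13_input_scale_quant,
)
experts = FlashInferExperts(
    g1_alphas=layer.g1_alphas,
    g2_alphas=layer.g2_alphas,
    a1_gscale=layer.w13_input_scale_quant,
    a2_gscale=layer.w2_input_scale_quant,
    out_dtype=moe.in_dtype,
    quant_dtype="nvfp4",
)
fused_experts = FusedMoEModularKernel(prepare_finalize, experts)

# The call site no longer threads extra_* dicts through forward().
out = fused_experts(
    hidden_states=x,
    w1=layer.w13_weight,
    w2=layer.w2_weight,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    activation="silu",
    global_num_experts=global_num_experts,
    expert_map=expert_map,
    w1_scale=layer.w13_blockscale_swizzled,
    w2_scale=layer.w2_blockscale_swizzled,
)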
@@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) self.prepare_finalize.finalize( - output, fused_out, topk_weights, topk_ids, + output, + fused_out, + topk_weights, + topk_ids, apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), - extra_finalize_args) + ) return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 46931f2dd7c7..401f37922b7b 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional, Union import pplx_kernels as pplx import torch @@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes( max_num_tokens: int, hidden_dim: int, in_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, block_shape: Optional[list[int]], ): @@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes( # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to 4 * sizeof(float32) (x4 for alignment) if quant_dtype is not None: + assert isinstance(quant_dtype, torch.dtype) assert quant_dtype.itemsize == 1 hidden_dim_bytes = hidden_dim * quant_dtype.itemsize elem_size = torch.float32.itemsize @@ -89,12 +90,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.num_dispatchers_ def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -213,11 +218,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 696c7cdba9a7..567a0a88fec0 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ 
b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -38,7 +38,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -50,32 +49,26 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - if (extra_prepare_args is not None - and extra_prepare_args.get("skip_quant", True)): - # Skip quantization if explicitly requested - return a1, None, None, None, None - a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: - if (extra_finalize_args is not None - and extra_finalize_args.get("skip_weight_reduce", True)): - assert output.shape == fused_expert_output.shape - output.copy_(fused_expert_output) - else: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input) + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 9d0ff2e06190..486ca881df48 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -119,18 +119,28 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts, expert_tokens_meta) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: 
Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) or is_blackwell_deep_gemm_e8m0_used())) @@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2, expert_tokens_meta, apply_router_weight_on_input, - extra_expert_args, ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 966471b5c59b..4c3e700ad399 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Any, Optional, Union +from typing import Optional, Union import torch @@ -189,7 +189,7 @@ def moe_kernel_quantize_input( return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) - elif quant_dtype == torch.uint8: # nvfp4 + elif quant_dtype == "nvfp4": return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled) @@ -252,17 +252,3 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" - - -def extract_required_args( - extra_args: Optional[dict[str, Any]], - required_keys: list[str], -) -> tuple[Any, ...]: - if extra_args is None: - raise ValueError("`extra_args` must be provided.") - - missing_keys = [k for k in required_keys if k not in extra_args] - if missing_keys: - raise ValueError(f"Missing keys in `extra_args`: {missing_keys}") - - return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index a9e967e608e9..fb285413ba9e 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: - return AWQMoEMethod(quant_args_marlin) + return AWQMoEMethod(quant_args_marlin, layer.moe) from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) @@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig): } return MoeWNA16Config.from_config(config).get_quant_method( layer, prefix) - return GPTQMarlinMoEMethod(quant_args_marlin) + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) if isinstance(layer, (LinearBase, ParallelLMHead)): if 
use_marlin: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index fe42e26a1706..af602eb9aca3 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig): } awq_marlin_config = AWQMarlinConfig.from_config( marlin_compatible_config_dict) - return AWQMoEMethod(awq_marlin_config) + return AWQMoEMethod(awq_marlin_config, layer.moe_config) return None diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index ed7ffb21e85a..287d66b06d6e 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig): "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - return AWQMoEMethod(self) + return AWQMoEMethod(self, layer.moe_config) return None @classmethod @@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase): class AWQMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: AWQMarlinConfig): + def __init__( + self, + quant_config: AWQMarlinConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config if self.quant_config.weight_bits != 4: raise ValueError("AWQMoEMethod only supports 4bit now.") @@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `AWQMoEMethod` yet.") @@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, @@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_map=expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, - workspace=layer.workspace) \ No newline at end of file + workspace=layer.workspace) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 0204ff46852f..b7897a43793c 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -7,6 +7,7 @@ import torch from packaging import version from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig): return UnquantizedLinearMethod() return BitsAndBytesLinearMethod(self) elif 
isinstance(layer, FusedMoE): - return BitsAndBytesMoEMethod(self) + return BitsAndBytesMoEMethod(self, layer.moe_config) return None @@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): quant_config: The BitsAndBytes quantization config. """ - def __init__(self, quant_config: BitsAndBytesConfig): + def __init__( + self, + quant_config: BitsAndBytesConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) try: import bitsandbytes if version.parse( @@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): raise ImportError("Please install bitsandbytes>=0.46.1 via " "`pip install bitsandbytes>=0.46.1` to use " "bitsandbytes quantizer.") from err - self.topk_indices_dtype = None self.quant_config = quant_config def create_weights( @@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 839942beaf40..42c43cbc03e5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering, QuantizationStrategy) import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa - FlashInferCutlassMoEPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_kernel, - flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -58,6 +59,9 @@ __all__ = [ class CompressedTensorsMoEMethod(FusedMoEMethodBase): + def __init_(self, moe: FusedMoEConfig): + super().__init__(moe) + @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): "WNA16MoE is not supported with actorder=group/dynamic." 
) logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config) + return CompressedTensorsWNA16MoEMethod(quant_config, + layer.moe_config) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") - return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + return CompressedTensorsWNA16MarlinMoEMethod( + quant_config, layer.moe_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod() + return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)): - return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + return CompressedTensorsW8A8Fp8MoEMethod(quant_config, + layer.moe_config) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod(quant_config) + return CompressedTensorsW8A8Int8MoEMethod(quant_config, + layer.moe_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self): + def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) + super().__init__(moe) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - self.fused_experts = None # type: ignore[assignment] + self.layer = layer def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): layer.w2_input_scale_quant = torch.nn.Parameter( (layer.w2_input_global_scale), requires_grad=False) - def maybe_swap_experts_impl(self, moe_parallel_config): + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: if not self.allow_flashinfer: - return - self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config) + return super().maybe_make_prepare_finalize(moe) - def select_gemm_impl(self, prepare_finalize, moe): + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return the appropriate GEMM experts implementation.""" - assert moe is not None and prepare_finalize is not None - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 - select_nvfp4_gemm_impl) - - return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) + experts = select_nvfp4_gemm_impl( + moe, + g1_alphas=self.layer.g1_alphas, + g2_alphas=self.layer.g2_alphas, + a1_gscale=self.layer.w13_input_scale_quant, + a2_gscale=self.layer.w2_input_scale_quant, + allow_flashinfer=self.allow_flashinfer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + 
return experts def apply( self, @@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsW4A4MoeMethod` yet.") @@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) if self.use_marlin: @@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): # FlashInfer fused experts path if self.fused_experts is not None: - return flashinfer_fp4_cutlass_moe_forward( - self.fused_experts, - layer, - x, - topk_weights, - topk_ids, + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, + w1_scale=layer.w13_blockscale_swizzled, + w2_scale=layer.w2_blockscale_swizzled, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - device=x.device, apply_router_weight_on_input=apply_router_weight_on_input).to( x.dtype) @@ -384,15 +419,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( "weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( "input_activations") - self.topk_indices_dtype = None per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR and self.input_quant.strategy @@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.weight_quant, self.input_quant) self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) - self.fused_experts = None # type: ignore[assignment] self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -614,25 +649,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ) -> FusedMoEPermuteExpertsUnpermute: # cutlass path if self.use_cutlass: - from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8 + from vllm.model_executor.layers.fused_moe import ( + CutlassBatchedExpertsFp8, CutlassExpertsFp8) - use_batched_format = (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts) + experts: FusedMoEPermuteExpertsUnpermute num_dispatchers = prepare_finalize.num_dispatchers() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) - logger.debug("CutlassExpertsFp8(%s)", 
self.__class__.__name__) - - experts = CutlassExpertsFp8( - num_experts, - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, - num_dispatchers=num_dispatchers, - use_batched_format=use_batched_format, - ) + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + logger.debug("CutlassBatchedExpertsFp8(%s)", + self.__class__.__name__) + experts = CutlassBatchedExpertsFp8( + moe.num_local_experts, + num_dispatchers, + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ) + else: + logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) + experts = CutlassExpertsFp8( + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ) self.disable_expert_map = (num_dispatchers > 1 or not experts.supports_expert_map()) @@ -834,9 +875,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( "weights") @@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( hidden_states=x, @@ -975,9 +1021,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
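The per-method edits in this file all follow the same shape: each FusedMoEMethodBase subclass now receives the layer's FusedMoEConfig in its constructor, the modular-kernel hooks (maybe_make_prepare_finalize / select_gemm_impl) are overridable instance methods, and apply() either dispatches to self.fused_experts or asserts that none was installed. The condensed sketch below restates that contract; MyQuantMoEMethod and its config are hypothetical, and create_weights() plus the routing call are elided for brevity.

# Sketch only: the constructor/dispatch pattern followed by the quantized
# MoE methods after this change.
from vllm.model_executor.layers.fused_moe import (FusedMoEConfig,
                                                  FusedMoEMethodBase)


class MyQuantMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config, moe: FusedMoEConfig):
        # Base __init__ stores self.moe and initializes
        # self.fused_experts = None and self.topk_indices_dtype = None.
        super().__init__(moe)
        self.quant_config = quant_config

    def select_gemm_impl(self, prepare_finalize, moe):
        # Return a FusedMoEPermuteExpertsUnpermute whose scales/biases are
        # passed via its constructor; init_prepare_finalize() wraps it and
        # prepare_finalize in a FusedMoEModularKernel stored on
        # self.fused_experts.
        raise NotImplementedError

    def apply(self, layer, x, router_logits, top_k, **kwargs):
        # Methods without a modular-kernel path guard against one having
        # been installed by init_prepare_finalize().
        assert self.fused_experts is None
        # ... compute topk_weights / topk_ids, passing
        # indices_type=self.topk_indices_dtype to the router so the dtype
        # expected by any prepare/finalize backend is honored, then run the
        # quantized experts kernel.
        ...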
@@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, @@ -1279,9 +1330,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsWNA16MoEMethod` yet.") @@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 47eca80609e0..3e43caa4cbf7 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -6,7 +6,8 @@ from typing import Any, Callable, Optional import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig): if isinstance(layer, LinearBase): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - return ExpertsInt8MoEMethod(self) + return ExpertsInt8MoEMethod(self, layer.moe_config) return None class ExpertsInt8MoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: ExpertsInt8Config): + def __init__( + self, + quant_config: ExpertsInt8Config, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert 
self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `ExpertsInt8MoEMethod` yet.") @@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dbd523428695..a49744913251 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import TYPE_CHECKING, Any, Callable, Optional import torch @@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self) + return Fp8MoEMethod(self, layer.moe_config) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: Fp8Config): - - from vllm.model_executor.layers.fused_moe import fused_experts + def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig): + super().__init__(moe) self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None @@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") - self.topk_indices_dtype = None - self.fused_experts = functools.partial( # type: ignore - fused_experts, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm)) - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, topk_group=topk_group, apply_router_weight_on_input=apply_router_weight_on_input) - else: + elif self.fused_experts is not None: return self.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase): a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm)) class 
Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 86da04c39989..49d28927d6e7 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization import QuantizationMethods @@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig): elif isinstance(layer, VocabParallelEmbedding): return GGUFEmbeddingMethod(self) elif isinstance(layer, FusedMoE): - return GGUFMoEMethod(self) + return GGUFMoEMethod(self, layer.moe_config) return None @@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase): quant_config: The GGUF quantization config. """ - def __init__(self, quant_config: GGUFConfig): + def __init__( + self, + quant_config: GGUFConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ): + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `GGUFMoEMethod` yet.") @@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, topk_weights, topk_ids, layer.w13_qweight_type.weight_type, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3299221e3af3..bd14ab9ef6c6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) @@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase): """MoE Marlin method with quantization.""" - def __init__(self, quant_config: GPTQMarlinConfig) -> None: + def __init__( + self, + quant_config: GPTQMarlinConfig, + moe: FusedMoEConfig, + ) -> None: + super().__init__(moe) self.quant_config = quant_config if self.quant_config.quant_type.size_bits == 4: self.quant_type = scalar_types.uint4b8 @@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported 
for `GPTQMarlinMoEMethod` yet.") @@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 22fbbab00e91..e0f462b36976 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -12,7 +12,9 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_kernel, - flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) @@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptFp8MoEMethod(self) + return ModelOptFp8MoEMethod(self, layer.moe_config) return None @@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): quant_config: The ModelOpt quantization config. 
""" - def __init__(self, quant_config: ModelOptFp8Config) -> None: + def __init__( + self, + quant_config: ModelOptFp8Config, + moe: FusedMoEConfig, + ) -> None: + super().__init__(moe) self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported) @@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") @@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) @@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptNvFp4FusedMoE(self) + return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer) return None @@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): quant_config: NVFP4 Quant Config """ - def __init__(self, quant_config: ModelOptNvFp4Config) -> None: - self.quant_config = quant_config + def __init__( + self, + quant_config: ModelOptNvFp4Config, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) + super().__init__(moe) + self.quant_config = quant_config + self.layer = layer _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer @@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.fused_experts: Optional[ mk.FusedMoEModularKernel] = None # type: ignore[assignment] - def maybe_swap_experts_impl( + def maybe_make_prepare_finalize( self, - moe_parallel_config: FusedMoEParallelConfig, - ): + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: if not self.allow_flashinfer: - return - self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config) + return super().maybe_make_prepare_finalize(moe) - # This method update self.fused_experts - # only prepare_finalize is not None call select_gemm_impl - # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert - # when it's not called(TP case), we still have 2 kernels to use. 
- def select_gemm_impl(self, prepare_finalize, - moe) -> mk.FusedMoEPermuteExpertsUnpermute: + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize - assert moe is not None and prepare_finalize is not None - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 - select_nvfp4_gemm_impl) - - return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_nvfp4_gemm_impl( + moe, + g1_alphas=self.layer.g1_alphas, + g2_alphas=self.layer.g2_alphas, + a1_gscale=self.layer.w13_input_scale_quant, + a2_gscale=self.layer.w2_input_scale_quant, + allow_flashinfer=self.allow_flashinfer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def uses_weight_scale_2_pattern(self) -> bool: """ @@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) if self.use_marlin: return torch.ops.vllm.fused_marlin_moe( @@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - device=x.device, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) else: assert self.allow_flashinfer and \ self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - out = flashinfer_fp4_cutlass_moe_forward( - self.fused_experts, - layer, - x, - topk_weights, - topk_ids, + + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, + w1_scale=layer.w13_blockscale_swizzled, + w2_scale=layer.w2_blockscale_swizzled, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index c5055a02fa3d..364d1ac314d2 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -7,7 +7,7 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig): else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): - return MoeWNA16Method(self) + return MoeWNA16Method(self, layer.moe_config) return None @@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase): 
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __init__(self, quant_config: MoeWNA16Config): + def __init__( + self, + quant_config: MoeWNA16Config, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `MoeWNA16Method` yet.") @@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index dbe6c603c062..3c5d83037cde 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig): class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): - super().__init__() + super().__init__(moe) self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 6f69210d0861..58f56c6381b3 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -7,7 +7,8 @@ import torch from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( OCP_MX_BLOCK_SIZE) @@ -25,6 +26,9 @@ __all__ = [ class QuarkMoEMethod(FusedMoEMethodBase): + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + @staticmethod def get_moe_method( quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 @@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase): input_config = layer_quant_config.get("input_tensors") if quant_config._is_fp8_w8a8(weight_config, input_config): - return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + return QuarkW8A8Fp8MoEMethod(weight_config, input_config, + module.moe_config) elif quant_config._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) + return QuarkW4A4MXFp4MoEMethod(weight_config, input_config, + module.moe_config) else: raise RuntimeError("Unsupported FusedMoe scheme") class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): - def __init__(self, weight_config: dict[str, Any], input_config: dict[str, - Any]): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) self.weight_quant = weight_config self.input_quant = input_config @@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = 
None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") @@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, @@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): - def __init__(self, weight_config: dict[str, Any], input_config: dict[str, - Any]): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) self.weight_quant = weight_config self.input_quant = input_config @@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( @@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) out = fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index cceaf9857c40..8bdb50e07b13 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -10,7 +10,8 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig): if isinstance(layer, LinearBase): return RTNLinearMethod(self) elif isinstance(layer, FusedMoE): - return RTNMoEMethod(self) + return RTNMoEMethod(self, layer.moe_config) return None @@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase): class RTNMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: RTNConfig): + def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `RTNMoEMethod` yet.") @@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) weight_bits = 
self.quant_config.weight_bits group_size = self.quant_config.group_size diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 8ef91eeed406..f5d7c57fe2a8 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -3,33 +3,30 @@ """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" from __future__ import annotations -from typing import Optional - import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) + FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize) from vllm.platforms import current_platform - -logger = init_logger(__name__) +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe __all__ = [ "is_flashinfer_fp4_cutlass_moe_available", "reorder_w1w3_to_w3w1", - "build_flashinfer_fp4_cutlass_moe_kernel", - "flashinfer_fp4_cutlass_moe_forward", + "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda() + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() and current_platform.is_device_capability(100)) @@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, dim=dim).contiguous()) -def build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel: - """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel""" - experts = FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe_parallel_config.dp_size > 1, - ep_rank=moe_parallel_config.ep_rank, - ep_size=moe_parallel_config.ep_size, - tp_rank=moe_parallel_config.tp_rank, - tp_size=moe_parallel_config.tp_size, - ) - logger.debug_once("FlashInferExperts (util)") - return mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8), - experts, - ) - - -def flashinfer_fp4_cutlass_moe_forward( - fused_experts: mk.FusedMoEModularKernel, - layer: torch.nn.Module, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, -) -> torch.Tensor: - """Common forward wrapper for FlashInfer NV-FP4 fused-MoE""" - - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, - layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!") - - a1_gscale = layer.w13_input_scale_quant - a2_gscale = layer.w2_input_scale_quant - - extra_expert_args = { - "g1_alphas": layer.g1_alphas, - "g2_alphas": layer.g2_alphas, - # Avoid confusion with a1_scale and a2_scale - # where are batch size related. 
- "a1_gscale": a1_gscale, - "a2_gscale": a2_gscale, - "out_dtype": x.dtype, - } - extra_prepare_args = { - "use_dp": layer.dp_size > 1, - "local_tokens": x.shape[0], - "a1_gscale": a1_gscale, - } - extra_finalize_args = { - "use_dp": layer.dp_size > 1, - "local_tokens": x.shape[0], - } - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, - apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args, - extra_prepare_args=extra_prepare_args, - extra_finalize_args=extra_finalize_args, - ) +def build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe: FusedMoEConfig, + a1_gscale: torch.Tensor, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 + return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) def select_nvfp4_gemm_impl( - allow_flashinfer: bool, - moe, # FusedMoEConfig - logger): + moe: FusedMoEConfig, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + allow_flashinfer: bool, +) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" - # lazy import - from vllm.distributed import get_ep_group - - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - if allow_flashinfer: - flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_backend != "throughput": - raise ValueError( - f"Only throughput backend is supported for FlashInferExperts, " - f"but got {flashinfer_backend}.") - logger.debug_once( - "Initializing FlashInferExperts with throughput backend.") return FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe.moe_parallel_config.dp_size > 1, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + out_dtype=moe.in_dtype, + quant_dtype="nvfp4", ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank,