diff --git a/CMakeLists.txt b/CMakeLists.txt index cad9f44286536..270c480003e77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -288,6 +288,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" "csrc/attention/mla/cutlass_mla_entry.cu") @@ -495,7 +496,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") @@ -533,7 +536,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible # to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py new file mode 100644 index 0000000000000..0d091b47c3e12 --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -0,0 +1,408 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe +kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit +activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8) +and 16-bit activations. 
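+Both paths are timed eagerly and under CUDA graph capture.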
+""" +import nvtx +import torch +import torch.utils.benchmark as benchmark + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.scalar_type import scalar_types +from vllm.utils import FlexibleArgumentParser + +WEIGHT_SHAPES_MOE = { + "nvidia/DeepSeek-R1-FP4": [ + [256, 8, 2048, 7168], + ], +} + +DEFAULT_MODELS = [ + "nvidia/DeepSeek-R1-FP4", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def bench_run(results: list[benchmark.Measurement], model: str, + num_experts: int, topk: int, per_act_token: bool, + per_out_ch: bool, mkn: tuple[int, int, int]): + label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, + mkn)) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + _, a_fp8_scale = ops.scaled_fp8_quant(a) + + w1_fp8q = torch.empty((num_experts, 2 * n, k), + device=device, + dtype=torch.float8_e4m3fn) + w2_fp8q = torch.empty((num_experts, k, n), + device=device, + dtype=torch.float8_e4m3fn) + w1_fp8scale = torch.empty((num_experts, 1, 1), + device=device, + dtype=torch.float32) + w2_fp8scale = torch.empty((num_experts, 1, 1), + device=device, + dtype=torch.float32) + + for expert in range(num_experts): + w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + w1_fp8q_notransp = w1_fp8q.clone() + w2_fp8q_notransp = w2_fp8q.clone() + w1_fp8q = w1_fp8q.transpose(1, 2) + w2_fp8q = w2_fp8q.transpose(1, 2) + + score = torch.randn((m, num_experts), device=device, dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + quant_blocksize = 16 + w1_blockscale = torch.empty((num_experts, 2 * n, k // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn) + w2_blockscale = torch.empty((num_experts, k, n // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn) + + # n_b_scales = 2 * n if per_out_ch else 1 + # k_b_scales = k if per_out_ch else 1 + w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), + device=device, + dtype=torch.uint8) + w2_fp4 = torch.empty((num_experts, k, n // 2), + device=device, + dtype=torch.uint8) + + w1_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32) + w2_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32) + a1_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32) + a2_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_e = w1[expert] + w2_e = w2[expert] + w1_amax = 
torch.abs(w1_e).max().to(torch.float32) + w2_amax = torch.abs(w2_e).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1_e, w1_gs[expert]) + + w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2_e, w2_gs[expert]) + + def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, num_repeats: int): + for _ in range(num_repeats): + fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale) + + def run_cutlass_moe_fp4(a: torch.Tensor, w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, w1_gs: torch.Tensor, + w2_gs: torch.Tensor, a1_gs: torch.Tensor, + a2_gs: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, m: int, n: int, k: int, + e: int, device: torch.device, num_repeats: int): + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp4", color="green"): + cutlass_moe_fp4(a=a, + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device) + + def run_cutlass_from_graph( + a: torch.Tensor, a1_gscale: torch.Tensor, w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, + k: int, e: int, device: torch.device): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp4(a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_alphas, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device) + + def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph(a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, 
stream=triton_stream): + run_triton_from_graph(a, w1_fp8q_notransp, w2_fp8q_notransp, + topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, + a_fp8_scale) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_fp8q_notransp": w1_fp8q_notransp, + "w2_fp8q_notransp": w2_fp8q_notransp, + "w1_fp8scale": w1_fp8scale, + "w2_fp8scale": w2_fp8scale, + "a_fp8_scale": a_fp8_scale, + # Cutlass params + "a": a, + "a1_gscale": a1_gs, + "w1_fp4": w1_fp4, + "w1_blockscale": w1_blockscale, + "w1_alphas": w1_gs, + "a2_gscale": a2_gs, + "w2_fp4": w2_fp4, + "w2_blockscale": w2_blockscale, + "w2_alphas": w2_gs, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "m": m, + "n": n, + "k": k, + "e": num_experts, + "device": device, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe_fp4": run_cutlass_moe_fp4, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, + topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + + run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_gs, + w2_gs, a1_gs, a2_gs, topk_weights, topk_ids, m, n, k, + num_experts, device, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run(results, model, num_experts, topk, + per_act_token, per_out_ch, mkn) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = 
FlexibleArgumentParser( + description="Benchmark NVFP4 CUTLASS MOE across specified " + "models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", + nargs="+", + type=int, + default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/csrc/ops.h b/csrc/ops.h index 1dfd2e067e851..21c5a9e297400 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -208,6 +208,12 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets); + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -235,6 +241,12 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a); void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index ddcc48cccab12..9843cd857d48a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -37,12 +37,6 @@ void cutlass_moe_mm_sm90( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); - #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -53,6 +47,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif +#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ + defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); +#endif + void 
cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -224,7 +227,8 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k); diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu new file mode 100644 index 0000000000000..45ec3d29ce045 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -0,0 +1,402 @@ +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include + +using namespace cute; + +template +__global__ void __get_group_gemm_starts( + ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, + ElementSF** a_scales_offsets, ElementSF** b_scales_offsets, + ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, + ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int, + ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets, + const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes, + const int K, const int N) { + int64_t expert_id = threadIdx.x; + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + // Originally int32_t but upcasting to int64_t to avoid overflow + // during offset calculations + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t sf_offset = static_cast(sf_offsets[expert_id]); + // size for block in block scale. 
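+  // Each FP8 block scale covers 16 consecutive elements along K (the NVFP4 group size).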
+ int64_t group_size = 16; + int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]); + int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]); + int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]); + assert((m >= 0 && n == N && k == K && k % 2 == 0) && + "unexpected problem sizes"); + + int64_t half_k = static_cast(k / 2); + int64_t group_k = static_cast(k / group_size); + // Shape of A as uint8/byte = [M, K // 2] + // Shape of B as uint8/byte = [E, N, K // 2] + a_offsets[expert_id] = a_base_as_int + expert_offset * half_k; + + b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k; + // Shape of C = [M, N] + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + // Shape of a_scale = [sum(sf_sizes), K // group_size] + a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k; + + assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + + // Shape of B scale = [E, N, K // group_size] + b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k; + assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + // Shape of alpha = [E] + alpha_offsets[expert_id] = alphas_base_as_int + expert_id; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); +} + +#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \ + TENSOR_C_TYPE, C_TYPE, LayoutSFA, \ + LayoutSFB, ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + __get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(a_starts.data_ptr()), \ + static_cast(b_starts.data_ptr()), \ + static_cast(out_starts.data_ptr()), \ + static_cast(a_scales_starts.data_ptr()), \ + static_cast(b_scales_starts.data_ptr()), \ + static_cast(alpha_starts.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast(alphas.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(sf_offsets.data_ptr()), \ + static_cast(problem_sizes.data_ptr()), K, N); \ + } + +template +void run_get_group_gemm_starts( + const torch::Tensor& a_starts, const torch::Tensor& b_starts, + const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, + const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + /*these are used for their base addresses*/ + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor const& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& alphas, + torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets, + torch::Tensor const& problem_sizes, int M, int N, int K) { + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + TORCH_CHECK(out_tensors.size(1) == N, + "Output tensor shape 
doesn't match expected shape"); + TORCH_CHECK(K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); + if (false) { + } + //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, + // ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, + cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, torch::kFloat16, + half, LayoutSFA, LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using EpilogueOperatorClass = + cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag + using MainloopOperatorClass = + cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = cutlass::gemm:: + KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + }; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape, + ClusterShape, Shape<_128, _64>, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA, + ElementB, LayoutB*, AlignmentB, ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB 
= typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm100GroupParams< + typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + 
{num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +#define CHECK_TYPE(x, st, m) \ + TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) \ + TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + // Input validation + CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); + CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); + CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); + CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); + + TORCH_CHECK(a_blockscale.dim() == 2, + "expected a_blockscale to be of shape [num_experts, rounded_m," + " k // group_size], observed rank: ", + a_blockscale.dim()) + TORCH_CHECK(b_blockscales.dim() == 3, + "expected b_blockscale to be of shape: " + " [num_experts, n, k // group_size], observed rank: ", + b_blockscales.dim()) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have the shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32."); + + int M = static_cast(a.size(0)); + int N = static_cast(b.size(1)); + int E = static_cast(b.size(0)); + int K = static_cast(2 * b.size(2)); + + if (output.scalar_type() == torch::kBFloat16) { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } else { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel, vLLM must " + "be compiled with ENABLE_NVFP4 for SM100+ and CUDA " + "12.8 or above."); +#endif +} diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu new file mode 
100644 index 0000000000000..076c4a085337b --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -0,0 +1,404 @@ +#include + +#include +#include + +#include +#include + +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. 
+ int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. 
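+// Each thread quantizes 8 contiguous input elements per iteration; a pair of threads
+// covers one 16-element scale block, and the per-expert input/scale offsets select the
+// global scale and scale-factor output region used for a given row.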
+template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + // Find index within the experts. + int rowIdx_in_expert = 0; + int expert_idx = 0; + for (int i = 0; i < n_experts; i++) { + if (rowIdx >= input_offset_by_experts[i] && + rowIdx < input_offset_by_experts[i + 1]) { + rowIdx_in_expert = rowIdx - input_offset_by_experts[i]; + expert_idx = i; + break; + } + } + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void quant_impl(void* output, void* output_scale, void* input, + void* input_global_scale, void* input_offset_by_experts, + void* output_scale_offset_by_experts, int m_topk, int k, + int n_experts, cudaStream_t stream) { + // TODO: this multiProcessorCount should be cached. + int device; + cudaGetDevice(&device); + int multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, + device); + + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(k / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). 
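+  // 2048 resident threads per SM is assumed; with block.x capped at 512 this yields at
+  // least 4 blocks per SM, matching the __launch_bounds__(512, 4) hint on the kernel.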
+ int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM)); + + cvt_fp16_to_fp4<<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), n_experts); +} + +/*Quantization entry for fp4 experts quantization*/ +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + +constexpr auto HALF = at::ScalarType::Half; +constexpr auto BF16 = at::ScalarType::BFloat16; +constexpr auto FLOAT = at::ScalarType::Float; +constexpr auto INT = at::ScalarType::Int; +constexpr auto UINT8 = at::ScalarType::Byte; + +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + CHECK_INPUT(output, "output must be a CUDA tensor"); + CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); + CHECK_INPUT(input, "input must be a CUDA tensor"); + CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); + CHECK_INPUT(input_offset_by_experts, + "input_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(output_scale_offset_by_experts, + "output_scale_offset_by_experts must be a CUDA tensor"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + auto m_topk = input.size(0); + auto k = input.size(1); + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. 
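+  // scales_k is rounded up to a multiple of 4 FP8 values so each output_scale row packs
+  // into whole int32 words (checked against output_scale.size(1) below).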
+ int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + + auto in_dtype = input.dtype(); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + if (in_dtype == at::ScalarType::Half) { + quant_impl(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, + n_experts, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, + k, n_experts, stream); + } else { + TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); + } +} \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index b1426c43b456b..badbb7e310df0 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, torch::Tensor const& input_sf); #endif +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if defined ENABLE_NVFP4 && ENABLE_NVFP4 return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); +} + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled nvfp4 experts quantization kernel"); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7ca40a5e78276..1dbd11f5f2a5b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -363,6 +363,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // cutlass nvfp4 block scaled group GEMM + ops.def( + "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," + " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," + " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", + {stride_tag}); + ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( @@ -492,6 +500,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! 
output_scale, Tensor input_scale) -> ()"); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Compute NVFP4 experts quantization. + ops.def( + "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices // of the given capability ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py new file mode 100644 index 0000000000000..ae63b379f39d1 --- /dev/null +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + quant_blocksize = 16 + round_up = lambda x, y: (x + y - 1) // y * y + sf_w1_2n = round_up(2 * n, 128) + sf_w1_k = round_up(k // quant_blocksize, 4) + w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), + device="cuda", + dtype=torch.float8_e4m3fn) + + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + sf_w2_k = round_up(k, 128) + sf_w2_n = round_up(n // quant_blocksize, 4) + w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), + device="cuda", + dtype=torch.float8_e4m3fn) + + w1_q = torch.empty((e, 2 * n, k // 2), + device="cuda", + dtype=torch.uint8) + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) + w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_amax = torch.abs(w1).max().to(torch.float32) + w2_amax = torch.abs(w2).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1[expert], w1_gs[expert]) + + w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2[expert], w2_gs[expert]) 
+ + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + + cutlass_output = cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_q, + w1_blockscale=w1_blockscale, + w1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + w2_fp4=w2_q, + w2_blockscale=w2_blockscale, + w2_alphas=(1 / w2_gs), + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=e, + device=a.device, + ) + + # Reference check: + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=quant_blocksize) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=1e-1, + rtol=1e-1) + + +if __name__ == "__main__": + test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py new file mode 100644 index 0000000000000..58eaeee1c0b88 --- /dev/null +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.scalar_type import scalar_types + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_nvfp4_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. 
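+    # so the packed width is k // 2 and each row carries k // block_size FP8 scale factors.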
+ assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index b08026c5867da..1f49900b2d90b 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", @@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES) SEEDS = [42] CUDA_DEVICES = ['cuda:0'] -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -kE2M1ToFloatArray = [ - 0., - 0.5, - 1., - 1.5, - 2., - 3., - 4., - 6., -] - - -def e2m1_to_fp32(int4_value): - signBit = (int4_value & 0x8) - int4_absValue = int4_value & 0x7 - float_result = kE2M1ToFloatArray[int4_absValue] - if (signBit): - float_result = -float_result - return float_result - - -def break_fp4_bytes(a, dtype): - assert (a.dtype == torch.uint8) - m, n = a.shape - a = a.flatten() - # Get upper 4 bits - highHalfByte = (a & 0xF0) >> 4 - # Get lower 4 bits - lowHalfByte = a & 0x0F - fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device) - fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device) - # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC] - out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2) - return out - - -def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): - sf_m, sf_k = a_sf_swizzled.shape - m_tiles = (m + 128 - 1) // 128 - f = block_size * 4 - k_tiles = (k + f - 1) // f - tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) - return out[0:m, 0:k] - - -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): - """Dequantize the fp4 tensor back to high precision.""" - # Two fp4 values are packed into one uint8. 
- assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape - k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) - tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale - - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) - return out - def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, m, n, dtype, block_size, device): _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert (m_k == n_k) - a_in_dtype = dequantize_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) return torch.matmul(a_in_dtype, b_in_dtype.t()) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0206d4552c8b2..80f5497452197 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -745,10 +745,11 @@ def get_cutlass_moe_mm_data( - output_permutation: Permutation that must be used to shuffle the output after executing the MMs. """ - torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, - input_permutation, output_permutation, - num_experts, n, k) + return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, + output_permutation, + num_experts, n, k) def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, @@ -767,9 +768,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, MMs used in the fused MoE operation. - a/b/c_strides: The data strides passed to grouped matrix multiplication. """ - torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, - b_scales, expert_offsets, problem_sizes, - a_strides, b_strides, c_strides) + return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, + c_strides) + + +def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + alphas: torch.Tensor, problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, + out_dtype: torch.dtype, device: torch.device): + """ + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + the gemms for each combination based on the specified problem sizes. + + This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. + - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized + input and expert weights. + - a_/b_scales: The blockscales in FP8-E4M3 precision + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is + sf_offset[E+1] - sf_offset[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. 
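+      For example (illustrative values), expert_offsets = [0, 4, 4, 9] means
+      expert 0 computes rows [0, 4) of a_tensors, expert 1 computes no rows,
+      and expert 2 computes rows [4, 9).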
+ """ + m_topk = a_tensors.shape[0] + n = b_tensors.shape[1] + c_shape = (m_topk, n) + c = torch.empty(c_shape, device=device, dtype=out_dtype) + torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, + b_scales, alphas, problem_sizes, + expert_offsets, sf_offsets) + return c.to(out_dtype) # aqlm @@ -960,6 +993,57 @@ def scaled_fp4_quant( return output, output_scale +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, + expert_map: Optional[torch.Tensor] = None, + MAX_TOKENS_PER_EXPERT: int = 163840, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + packed MoE Inputs. + Args: + input: The input tensor to be quantized to FP4 + expert_map: The expert map tensor + input_global_scale: A scalar scaling factor for the entire tensor. + expert_offsets: The expert offsets tensor + blockscale_offsets: The blockscale offsets tensor + Outputs: + output: The quantized tensor in FP4 + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert not current_platform.is_rocm() + assert input_tensor.ndim == 2, ( + f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + + input_tensor = input_tensor[ + expert_map] if expert_map is not None else input_tensor + m_numtopk, k = input_tensor.shape + assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for" + f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}") + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty(m_numtopk, + k // 2, + device=input_tensor.device, + dtype=torch.uint8) + output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device) + torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, + input_global_scale, expert_offsets, + blockscale_offsets) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384ff..53e7769b2042c 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -36,7 +36,7 @@ if HAS_TRITON: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) + cutlass_moe_fp4, cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) @@ -48,4 +48,5 @@ if HAS_TRITON: "get_config_file_name", "grouped_topk", "cutlass_moe_fp8", + "cutlass_moe_fp4", ] diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f8348571..1b34e952208af 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -"""Fused MoE kernel.""" +""" CUTLASS based Fused MoE kernels.""" from typing import Optional import torch from vllm import _custom_ops as ops +from vllm.scalar_type import scalar_types #TODO make the grouped gemm 
kernel consistent with scaled gemm kernel
@@ -178,3 +179,126 @@ def cutlass_moe_fp8(
     if not apply_router_weight_on_input:
         c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
     return c2.sum(dim=1)
+
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+MAX_TOKENS_PER_EXPERT = 65536
+
+
+def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
+                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
+                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
+                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
+                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
+                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
+                    device: torch.device):
+    """
+    MoE implementation for FP4 inputs.
+
+    # Gemm 1
+    a: Input tensor: [m, k] (half/bfloat16)
+    a1_gscale: Activation scale per expert: [e] (float32)
+    w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
+    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
+    (Note: `n` is the up projection output dim, `k` is the input dim in
+     full precision)
+    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
+                   (Block size = 16 for NVFP4)
+
+    # Gemm 2
+    a2_gscale: Activation scale per expert: [e]
+    w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
+    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
+    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
+
+    topk_weights: [m, topk], dtype: float32
+    topk_ids: [m, topk], dtype: int32
+
+    m, n, k: Unquantized weight shapes, dtype: int
+    e: number of experts, dtype: int
+
+    Assumes that topk < k < n, as required by the up/down projections.
+    """
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
+    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
+    assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
+            and w2_blockscale.ndim
+            == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
+    m_a, k_a = a.shape
+    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
+    e_w2, k_w2, half_n_w2 = w2_fp4.shape
+
+    assert (e_w1 == e_w2 and e_w1 == e), (
+        "Number of experts must match between weights.")
+    assert (k_a // 2 == half_k_w1
+            and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
+    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
+                                                       "expected `n`")
+    assert (m == m_a), "input shape mismatch"
+    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
+    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
+    assert (topk_weights.shape[0] == m and topk_ids.shape[0]
+            == m), ("topk must be provided for each row of a")
+    assert (m <= MAX_TOKENS_PER_EXPERT), (
+        f"m must not exceed MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
+        f" for cutlass_moe_fp4, observed m = {m}")
+    out_dtype = a.dtype
+    num_topk = topk_ids.shape[1]
+
+    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
+    # Problem size: (num_experts, (m,2n,k))
+    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
+    # Problem size: (num_experts, (m,n,k))
+    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    # problem shapes should have [m, n, k]
+    # Note that problem sizes are based on logical number of elements.
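+    # (Here `k` counts unpacked fp4 elements; the packed uint8 storage for a
+    # row holds only k // 2 bytes.)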
+    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
+                                problem_sizes2, a_map, c_map, e, n, k)
+
+    tokens_per_expert = problem_sizes1[:, 0]
+    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
+    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
+    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
+
+    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
+        a,
+        a1_gscale,
+        expert_offsets,
+        blockscale_offsets,
+        num_topk,
+        expert_map=a_map,
+        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+
+    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
+                                w1_blockscale, w1_alphas, problem_sizes1,
+                                expert_offsets[:-1], blockscale_offsets[:-1],
+                                out_dtype, device)
+    del rep_a_fp4, rep_a_blockscale
+    # silu_and_mul halves the hidden size dimension:
+    # [m * topk, 2 * n] -> [m * topk, n]
+    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
+                               device=device,
+                               dtype=out_dtype)
+
+    torch.ops._C.silu_and_mul(intermediate, c1)
+
+    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
+        intermediate,
+        a2_gscale,
+        expert_offsets,
+        blockscale_offsets,
+        num_topk,
+        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+
+    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
+                                w2_alphas, problem_sizes2, expert_offsets[:-1],
+                                blockscale_offsets[:-1], out_dtype, device)
+    del int_fp4, int_blockscale
+    out = (c2[c_map].view(m, num_topk, k) *
+           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
+    return out.to(dtype=out_dtype)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 35994c8ac6af0..5337ff0037da4 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -643,7 +643,7 @@ class FusedMoE(torch.nn.Module):
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
-
+        quant_method_name = self.quant_method.__class__.__name__
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -697,8 +697,9 @@ class FusedMoE(torch.nn.Module):
             # this is needed for compressed-tensors only
             loaded_weight = loaded_weight.to(param.data.device)
-            if param.data[expert_id] != 1 and (param.data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
+            if ("compressed" in quant_method_name.lower()
+                    and param.data[expert_id] != 1
+                    and (param.data[expert_id] - loaded_weight).abs() > 1e-5):
                 raise ValueError(
                     "input_scales of w1 and w3 of a layer "
                     f"must be equal. But got {param.data[expert_id]} "
@@ -718,6 +719,22 @@ class FusedMoE(torch.nn.Module):
                     tp_rank=self.tp_rank)
             return
 
+        if "ModelOpt" in quant_method_name:
+            if ('weight_scale_2' in weight_name
+                    or 'input_scale' in weight_name):
+                self._load_per_tensor_weight_scale(shard_id=shard_id,
+                                                   param=param,
+                                                   loaded_weight=loaded_weight,
+                                                   expert_id=expert_id)
+            elif "weight" in weight_name:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=self.tp_rank)
+            return
+
         # Case weight scales, zero_points and offset
         if ("scale" in weight_name or "zero" in weight_name
                 or "offset" in weight_name):
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 828447dd10193..e9b16b8a0acd0 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
 from vllm._custom_ops import (cutlass_scaled_fp4_mm,
                               cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -210,25 +212,37 @@ class ModelOptNvFp4Config(QuantizationConfig):
                 "`hf_quant_config.json` file for your model's "
                 "quant configuration.")
         is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
-        if not (group_size and kv_cache_quant_algo and exclude_modules):
+        if not all(key in quant_config for key in
+                   ("group_size", "kv_cache_quant_algo", "exclude_modules")):
             raise ValueError("NVFP4 quantization requires group size and "
                              "kv_cache_quant_algo specified in "
                              "hf_quant_config.json")
+        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        group_size = quant_config["group_size"]
+        exclude_modules = quant_config["exclude_modules"]
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)
 
+    def is_layer_excluded(self, prefix: str, exclude_modules: List):
+        import re
+        for pattern in exclude_modules:
+            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
+            if re.fullmatch(regex_str, prefix):
+                return True
+        return False
+
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules):
+            if (is_layer_skipped(prefix, self.exclude_modules)
+                    or self.is_layer_excluded(prefix, self.exclude_modules)):
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return ModelOptNvFp4FusedMoE(self)
         return None
 
@@ -409,3 +423,235 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)
+
+
+class 
ModelOptNvFp4FusedMoE(FusedMoEMethodBase): + """ + MoE Method for FP4 Quantization. + Args: + quant_config: NVFP4 Quant Config + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + weight_loader = extra_weight_attrs.get("weight_loader") + # GEMM 1 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight", w2_weight) + + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // + self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + w13_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + + w2_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + + w13_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = 
torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # GEMM 1 + + assert torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), ( + "Expected w1_weight_scale_2 to equal w3_weight_scale_2") + + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, + requires_grad=False) + + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + torch.float32) + layer.g1_alphas = Parameter( + (w13_input_scale * w13_weight_scale_2).to(torch.float32), + requires_grad=False) + + assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w13_blockscale_swizzled = self.swizzle_blockscale( + layer.w13_weight_scale) + + layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False) + + # GEMM 2 + layer.g2_alphas = Parameter( + (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = Parameter( + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + + assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + + layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, + requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + assert activation == "silu", "Only SiLU activation is supported." 
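+        # The fp4 CUTLASS path fuses the activation via silu_and_mul between
+        # the two grouped GEMMs (see cutlass_moe_fp4), so only SiLU can be
+        # supported here.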
+ assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for ModelOptNvFp4FusedMoE.") + assert expert_map is None, ("Expert Parallelism /expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype)
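A minimal sketch, illustrative only: ModelOptNvFp4Config.from_config requires
"group_size", "kv_cache_quant_algo" and "exclude_modules" in the checkpoint's
hf_quant_config.json alongside an NVFP4 quant method. The key names and values
below are assumptions for illustration, not part of this change.

# Hypothetical quantization config, written as a Python dict, that would pass
# the checks in ModelOptNvFp4Config.from_config:
quant_config = {
    "quant_algo": "NVFP4",           # assumed key; only "NVFP4" in the method
                                     # string is actually checked
    "kv_cache_quant_algo": "FP8",    # example value (assumption)
    "group_size": 16,                # NVFP4 block size used throughout
    "exclude_modules": ["lm_head"],  # example exclusion pattern (assumption)
}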