diff --git a/CMakeLists.txt b/CMakeLists.txt index a14496e035d9a..c46fb18d7bfef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -799,24 +799,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") else() cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") - message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " - "if you intend on running FP8 quantized MoE models on Blackwell.") - else() - message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() # # Machete kernels diff --git a/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu deleted file mode 100644 index 6c8f6309ef43f..0000000000000 --- a/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu +++ /dev/null @@ -1,373 +0,0 @@ -#include "core/registration.h" - -#include -#include - -#include -#include -#include - -#include "cute/tensor.hpp" -#include "cutlass/tensor_ref.h" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" - -#include "cutlass/util/command_line.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/device/gemm.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/gett.hpp" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include - -using namespace cute; - -template -__global__ void get_ggemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, - ElementC** out_offsets, ElementAccumulator** a_scale_offsets, - ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int, - ElementAB* b_base_as_int, ElementC* out_base_as_int, - ElementAccumulator* a_scale_base_as_int, - ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int, - LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) { - int expert_id = threadIdx.x; - - if (expert_id >= gridDim.x * blockDim.x) { - return; - } - - int m = problem_sizes[expert_id * 3]; - int n = problem_sizes[expert_id * 3 + 1]; - int k = problem_sizes[expert_id * 3 + 2]; - - int32_t expert_offset = expert_offsets[expert_id]; - int a_stride = expert_offset * k; - int b_stride = expert_id * k * n; - int 
a_scale_stride = expert_offset * k / 128; - int b_scale_stride = expert_id * k * n / 128 / 128; - - a_offsets[expert_id] = a_base_as_int + a_stride; - b_offsets[expert_id] = b_base_as_int + b_stride; - out_offsets[expert_id] = out_base_as_int + expert_offset * n; - a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride; - b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride; - - LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; - LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; - - *layout_sfa_ptr = - ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); - *layout_sfb_ptr = - ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); -} - -#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \ - ScaleConfig) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ - get_ggemm_starts<<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ - static_cast(a_ptrs.data_ptr()), \ - static_cast(b_ptrs.data_ptr()), \ - static_cast(out_ptrs.data_ptr()), \ - static_cast(a_scales_ptrs.data_ptr()), \ - static_cast(b_scales_ptrs.data_ptr()), \ - static_cast(a_tensors.data_ptr()), \ - static_cast(b_tensors.data_ptr()), \ - static_cast(out_tensors.data_ptr()), \ - static_cast(a_scales.data_ptr()), \ - static_cast(b_scales.data_ptr()), \ - reinterpret_cast(layout_sfa.data_ptr()), \ - reinterpret_cast(layout_sfb.data_ptr()), \ - static_cast(problem_sizes.data_ptr())); \ - } - -template -void run_get_ggemm_starts( - torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, - torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, - torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, - torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, - torch::Tensor out_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& layout_sfa, - torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) { - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0); - TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0); - - int num_experts = (int)expert_offsets.size(0); - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - - if (false) { - } - __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, - LayoutSFB, ScaleConfig) - __CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA, - LayoutSFB, ScaleConfig) - else { - TORCH_CHECK(false, "Unsupported output tensor type"); - } -} - -template -void run_blockwise_scaled_group_mm( - torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs, - const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs, - const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a, - const torch::Tensor& stride_b, const torch::Tensor& stride_c, - const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { - using ProblemShape = cutlass::gemm::GroupProblemShape>; - - // Types - using ElementA = cutlass::float_e4m3_t; - using ElementB = cutlass::float_e4m3_t; - using ElementC = OutType; - using ElementD = ElementC; - using ElementAccumulator = float; - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = 
cutlass::layout::ColumnMajor; - using LayoutC = LayoutD; - - // Alignments - static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using ArchTag = cutlass::arch::Sm100; - using OperatorClass = cutlass::arch::OpClassTensorOp; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape, - typename ScheduleConfig::ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*, - AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, - cute::tuple, - AlignmentA, ElementB, - cute::tuple, - AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape, - typename ScheduleConfig::ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - typename ScheduleConfig::KernelSchedule>::CollectiveOp; - - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; - int num_experts = (int)expert_offsets.size(0); - - Gemm gemm_op; - - // Mainloop Arguments - typename GemmKernel::MainloopArguments mainloop_args{ - static_cast(a_ptrs.data_ptr()), - static_cast(stride_a.data_ptr()), - static_cast(b_ptrs.data_ptr()), - static_cast(stride_b.data_ptr()), - static_cast(a_scales_ptrs.data_ptr()), - reinterpret_cast( - layout_sfa.data_ptr()), - static_cast(b_scales_ptrs.data_ptr()), - reinterpret_cast( - layout_sfb.data_ptr())}; - - int device_id = a_ptrs.device().index(); - static const cutlass::KernelHardwareInfo hw_info{ - device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - device_id)}; - - // Epilogue Arguments - typename GemmKernel::EpilogueArguments epilogue_args{ - {}, // epilogue.thread - nullptr, - static_cast(stride_c.data_ptr()), - static_cast(out_ptrs.data_ptr()), - static_cast(stride_c.data_ptr())}; - - UnderlyingProblemShape* problem_sizes_as_shapes = - static_cast(problem_sizes.data_ptr()); - - // Gemm Arguments - typename GemmKernel::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {num_experts, problem_sizes_as_shapes, nullptr}, - mainloop_args, - epilogue_args, - hw_info}; - - at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()}; - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); - - auto can_implement_status = gemm_op.can_implement(args); - TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, 
"Failed to initialize GEMM"); - - status = gemm_op.run(stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); -} - -template -void blockwise_scaled_group_mm_dispatch_shape( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { - struct MmaConfig { - using ElementA = cutlass::float_e4m3_t; - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< - 1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; - using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); - using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); - using LayoutC = cutlass::layout::RowMajor; - using MmaTileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _1, _1>; - }; - - int num_experts = (int)expert_offsets.size(0); - - auto a_ptrs = torch::empty( - {num_experts}, - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto b_ptrs = torch::empty( - {num_experts}, - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto out_ptrs = torch::empty( - {num_experts}, - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto a_scales_ptrs = torch::empty( - {num_experts}, - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto b_scales_ptrs = torch::empty( - {num_experts}, - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - - auto layout_sfa = torch::empty( - {num_experts, 5}, - torch::TensorOptions().dtype(torch::kInt32).device(a.device())); - auto layout_sfb = torch::empty( - {num_experts, 5}, - torch::TensorOptions().dtype(torch::kInt32).device(a.device())); - - auto stride_a = torch::full( - {num_experts}, a.size(1), - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto stride_b = torch::full( - {num_experts}, a.size(1), - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto stride_c = torch::full( - {num_experts}, output.size(1), - torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - - torch::TensorOptions options_int = - torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - - run_get_ggemm_starts( - expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a, - b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes); - - run_blockwise_scaled_group_mm( - out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a, - stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes, - expert_offsets); -} - -void cutlass_blockwise_scaled_grouped_mm( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, - "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), - "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32"); - TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, - "a must be kFloat8_e4m3fn"); - TORCH_CHECK(b.scalar_type() == 
torch::kFloat8_e4m3fn, - "b must be kFloat8_e4m3fn"); - TORCH_CHECK(output.scalar_type() == torch::kBFloat16 || - output.scalar_type() == torch::kHalf, - "output must be bfloat16 or half"); - TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, - "scales_a must be float32"); - TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, - "scales_b must be float32"); - TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, - "expert_offsets must be int32"); - - TORCH_CHECK(output.dim() == 2, "output must be 2D tensor"); - TORCH_CHECK(a.dim() == 2, "a must be 2D tensor"); - TORCH_CHECK(b.dim() == 3, "b must be 3D tensor"); - TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor"); - TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor"); - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, - "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), - "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32"); - TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); - -#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100 - if (output.scalar_type() == torch::kBFloat16) { - blockwise_scaled_group_mm_dispatch_shape( - output, a, b, scales_a, scales_b, problem_sizes, expert_offsets); - } else if (output.scalar_type() == torch::kFloat16) { - blockwise_scaled_group_mm_dispatch_shape( - output, a, b, scales_a, scales_b, problem_sizes, expert_offsets); - } else { - TORCH_CHECK(false, "Unsupported output tensor type"); - } -#endif -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("cutlass_blockwise_scaled_grouped_mm", - &cutlass_blockwise_scaled_grouped_mm); -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 83d4943d62776..461f74ca184fd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -416,13 +416,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor alpha) -> ()"); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); - // cutlass blockwise scaledgroup GEMM - ops.def( - "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, " - "Tensor scales_a, Tensor scales_b, " - "Tensor problem_sizes, Tensor expert_offsets) -> ()"); - // conditionally compiled so impl registration is in source file - // cutlass nvfp4 block scaled group GEMM ops.def( "cutlass_fp4_group_mm(Tensor! 
out, Tensor a, Tensor b," diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py deleted file mode 100644 index 1c10cb3b2c699..0000000000000 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ /dev/null @@ -1,92 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# DeepGEMM Style Cutlass Grouped GEMM Test -# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py - -import random - -import pytest -import torch - -from tests.kernels.moe.utils import per_token_cast_to_fp8 -from tests.kernels.utils import baseline_scaled_mm -from vllm import _custom_ops as ops -from vllm.platforms import current_platform -from vllm.utils.deep_gemm import per_block_cast_to_fp8 -from vllm.utils.math_utils import cdiv - - -@pytest.mark.parametrize( - "num_groups, expected_m_per_group, k, n", - [ - (4, 8192, 7168, 4096), - (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), - (8, 4096, 2048, 7168), - (32, 1024, 7168, 4096), - (32, 1024, 2048, 7168), - ], -) -@pytest.mark.parametrize("out_dtype", [torch.float16]) -@pytest.mark.skipif( - (lambda x: x is None or x.to_int() != 100)( - current_platform.get_device_capability() - ), - reason="Block Scaled Grouped GEMM is only supported on SM100.", -) -def test_cutlass_grouped_gemm( - num_groups: int, - expected_m_per_group: int, - k: int, - n: int, - out_dtype: torch.dtype, -): - device = "cuda" - alignment = 128 - group_ms = [ - int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups) - ] - m = sum([cdiv(m, alignment) * alignment for m in group_ms]) - - x = torch.randn((m, k), device=device, dtype=out_dtype) - y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype) - out = torch.empty((m, n), device=device, dtype=out_dtype) - ref_out = torch.randn((m, n), device=device, dtype=out_dtype) - - ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m] - pb_size = [] - for i in range(num_groups): - pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k]) - problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32) - expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) - - x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = ( - torch.empty_like(y, dtype=torch.float8_e4m3fn), - torch.empty( - (num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float - ), - ) - for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128]) - - for i in range(num_groups): - a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]] - a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]] - b = y_fp8[0][i].t() - b_scale = y_fp8[1][i].t() - baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) - ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline - - ops.cutlass_blockwise_scaled_grouped_mm( - out, - x_fp8[0], - y_fp8[0], - x_fp8[1], - y_fp8[1], - problem_sizes, - expert_offsets[:-1], - ) - - torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index cf7f17a033be3..78bd8d4e64115 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -788,20 +788,6 @@ def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) -def cutlass_blockwise_scaled_grouped_mm( - output: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - scales_a: torch.Tensor, - scales_b: torch.Tensor, - problem_sizes: 
torch.Tensor, - expert_offsets: torch.Tensor, -): - torch.ops._C.cutlass_blockwise_scaled_grouped_mm( - output, a, b, scales_a, scales_b, problem_sizes, expert_offsets - ) - - def cutlass_scaled_fp4_mm( a: torch.Tensor, b: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 4a0b4e82c1b39..9281780fca478 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, ) -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize, _resize_cache +from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -896,162 +896,6 @@ def cutlass_moe_fp4( ) -def _valid_cutlass_block_scaled_grouped_gemm( - w1: torch.Tensor, - w2: torch.Tensor, - inplace: bool, - activation: str, - apply_router_weight_on_input: bool, - expert_map: torch.Tensor | None, -) -> bool: - def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): - return N % 128 == 0 and K % 128 == 0 - - _, K, N = w2.size() - if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K): - logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: unaligned problem size. " - "N: %s, K: %s", - N, - K, - ) - return False - - if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn: - logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). " - "w1.dtype: %s, w2.dtype: %s", - w1.dtype, - w2.dtype, - ) - return False - - if expert_map is not None: - logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: expert_parallel is not supported." - ) - return False - - if activation != "silu": - logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: only activation silu is supported." - ) - return False - - if apply_router_weight_on_input: - logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled:" - " apply_router_weight_on_input is not supported." - ) - return False - - if inplace: - logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: inplace is not supported." - ) - return False - - return True - - -# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8. 
-def run_cutlass_block_scaled_fused_experts( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - w1_q = w1.transpose(1, 2) - w2_q = w2.transpose(1, 2) - w1_scale = w1_scale.transpose(1, 2) - w2_scale = w2_scale.transpose(1, 2) - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert a.shape[0] == topk_ids.shape[0], ( - "a and topk_ids must have the same batch size" - ) - assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" - assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch" - assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - out_dtype = a.dtype - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - - topk = topk_ids.size(1) - - a_q, a1_scale = _fp8_quantize( - a, A_scale=None, per_act_token=False, block_shape=[128, 128] - ) - device = a_q.device - - expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device=device) - problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) - problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) - - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - - ops.get_cutlass_moe_mm_data( - topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, - a_map, - c_map, - num_experts, - n, - k, - ) - - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] - - c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device) - c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device) - - ops.cutlass_blockwise_scaled_grouped_mm( - c1, - rep_a_q, - w1_q, - rep_a1_scales, - w1_scale, - problem_sizes1, - expert_offsets[:-1], - ) - - intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) - torch.ops._C.silu_and_mul(intermediate, c1) - - intermediate_q, a2_scale = _fp8_quantize( - intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128] - ) - - ops.cutlass_blockwise_scaled_grouped_mm( - c2, - intermediate_q, - w2_q, - a2_scale, - w2_scale, - problem_sizes2, - expert_offsets[:-1], - ) - - return ( - c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) - ).sum(dim=1) - - # W4A8 def run_cutlass_moe_w4a8_fp8( output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 37f8e7780f999..c8d80ae023d43 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -25,10 +25,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, _get_config_dtype_str, ) -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - _valid_cutlass_block_scaled_grouped_gemm, - run_cutlass_block_scaled_fused_experts, -) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8, @@ -1678,11 +1674,9 @@ def 
fused_experts( expert_map: torch.Tensor | None = None, quant_config: FusedMoEQuantConfig | None = None, allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - use_fp8_w8a8 = quant_config.use_fp8_w8a8 # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. @@ -1712,23 +1706,6 @@ def fused_experts( a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif ( - allow_cutlass_block_scaled_grouped_gemm - and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm( - w1, w2, inplace, activation, apply_router_weight_on_input, expert_map - ) - ): - assert quant_config is not None - return run_cutlass_block_scaled_fused_experts( - a=hidden_states, - w1=w1, - w2=w2, - w1_scale=quant_config.w1_scale, - w2_scale=quant_config.w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ec3fc5ace17d8..78685538ea1b3 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -118,9 +118,8 @@ class Fp8MoeBackend(Enum): FLASHINFER_TRTLLM = 1 FLASHINFER_CUTLASS = 2 DEEPGEMM = 3 - CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4 - MARLIN = 5 - TRITON = 6 + MARLIN = 4 + TRITON = 5 def get_fp8_moe_backend( @@ -191,17 +190,6 @@ def get_fp8_moe_backend( logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local") return Fp8MoeBackend.DEEPGEMM - # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights - if ( - current_platform.is_cuda() - and current_platform.is_device_capability_family(100) - and block_quant - ): - logger.info_once( - "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local" - ) - return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM - # default to Triton logger.info_once("Using Triton backend for FP8 MoE") return Fp8MoeBackend.TRITON @@ -752,9 +740,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM - self.allow_cutlass_block_scaled_grouped_gemm = ( - self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM - ) def create_weights( self, @@ -1316,9 +1301,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_map=layer.expert_map, quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm - ), ) if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
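
Not part of the diff itself, but a hedged illustration of the caller-visible effect: `fused_experts` loses the `allow_cutlass_block_scaled_grouped_gemm` keyword, so any out-of-tree caller that passed it must simply drop it, and block-quantized FP8 MoE on SM100 now follows whichever remaining backend `get_fp8_moe_backend` selects (e.g. DeepGEMM, Marlin, or Triton). The sketch below uses placeholder shapes and unquantized bf16 weights purely to show the updated keyword surface; none of these tensors or sizes come from the diff.

```python
# Hypothetical migration sketch (not part of this change): placeholder tensors,
# unquantized weights, default quant config.
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

num_tokens, hidden, intermediate, num_experts, topk = 4, 1024, 2048, 8, 2
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")

# Assumed fused-MoE weight layout: w1 is (E, 2 * intermediate, hidden),
# w2 is (E, hidden, intermediate).
w1 = torch.randn(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(num_experts, hidden, intermediate, dtype=torch.bfloat16, device="cuda")

# Toy router: top-k over softmaxed logits; int32 expert ids to match the
# convention used elsewhere in the removed test code.
router_logits = torch.randn(num_tokens, num_experts, device="cuda")
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), k=topk, dim=-1)
topk_ids = topk_ids.to(torch.int32)

out = fused_experts(
    hidden_states=hidden_states,
    w1=w1,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    # allow_cutlass_block_scaled_grouped_gemm=True,  # removed by this change
    allow_deep_gemm=False,
)
```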