diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index b15254277cc75..2c3e32ca1d7c8 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -1,30 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from typing import Optional import pytest import torch import triton.language as tl -from typing import Optional import vllm._custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_moe_batched_triton_kernel, - BatchedExperts, - BatchedPrepareAndFinalize, - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, - get_default_config) + BatchedPrepareAndFinalize, BatchedTritonExperts, + invoke_moe_batched_triton_kernel) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8) from vllm.platforms import current_platform from vllm.utils import round_up - NUM_EXPERTS = [8, 64] TOP_KS = [1, 2, 6] @@ -80,10 +76,12 @@ class BatchedMMTensors: return BatchedMMTensors(A, B, C, num_expert_tokens) -def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, - As: torch.Tensor, Bs: torch.Tensor, +def native_w8a8_block_matmul(A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, block_size, - output_dtype = torch.bfloat16): + output_dtype=torch.bfloat16): """This function performs matrix multiplication with block-wise quantization using native torch. It is agnostic to the input data type and can be used for both int8 and @@ -160,16 +158,11 @@ def ref_impl( if A.dtype == torch.torch.float8_e4m3fn: if False: tmp = native_w8a8_block_matmul(A[e, :, :], - B[e].transpose(0, 1), - A_scale, - B_scale, - block_shape) + B[e].transpose(0, 1), A_scale, + B_scale, block_shape) else: - tmp = ops.cutlass_scaled_mm(A[e, :, :], - B[e].transpose(0, 1), - A_scale, - B_scale, - torch.bfloat16) + tmp = ops.cutlass_scaled_mm(A[e, :, :], B[e].transpose(0, 1), + A_scale, B_scale, torch.bfloat16) C[e, :num_tokens, :] = tmp[:num_tokens, :] else: C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) @@ -195,7 +188,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, in_dtype = dtype out_dtype = dtype - config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N) + config = BatchedMMConfig(in_dtype, out_dtype, num_experts, + max_tokens_per_expert, K, N) tensors = BatchedMMTensors.make_tensors(config) test_output = tensors.C @@ -209,7 +203,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, }[test_output.dtype] use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn - block_shape = [16, 16, 32] # 16 for k if not fp8 + block_shape = [16, 16, 32] # 16 for k if not fp8 #print(f"tensors.A {tensors.A.shape}") #print(f"tensors.B {tensors.B.shape}") @@ -250,19 +244,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ref_output = ref_output.to(dtype=out_dtype) ref_output = ref_impl(tensors.A.to(dtype=out_dtype), - tensors.B.to(dtype=out_dtype), - ref_output, - tensors.num_expert_tokens, - A_scale, - B_scale, + tensors.B.to(dtype=out_dtype), ref_output, + tensors.num_expert_tokens, A_scale, B_scale, block_shape[-2:]) - ref_output2 = ref_impl(tensors.A, - tensors.B, - ref_output2, - tensors.num_expert_tokens, - A_scale, - B_scale, + ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2, + tensors.num_expert_tokens, A_scale, B_scale, block_shape[-2:]) rtol, atol = { @@ -286,11 +273,17 @@ def batched_moe( use_fp8_w8a8: bool = False, block_shape: Optional[list[int]] = None, ) -> torch.Tensor: - max_num_tokens = round_up(a.shape[0], 64) # ? + max_num_tokens = round_up(a.shape[0], 64) # ? fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0, use_fp8_w8a8=use_fp8_w8a8, + BatchedPrepareAndFinalize(max_num_tokens, + world_size=1, + dp_size=1, + rank=0, + use_fp8_w8a8=use_fp8_w8a8, block_shape=block_shape), - BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1, + BatchedTritonExperts(max_num_tokens=max_num_tokens, + dp_size=1, + world_size=1, use_fp8_w8a8=use_fp8_w8a8, block_shape=block_shape)) @@ -322,11 +315,13 @@ def torch_moe2( if use_fp8_w8a8: a, a_scale = per_token_group_quant_fp8(a, block_shape[1]) - #print(f"a_scale {a_scale.shape}") else: a_scale = None - out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + out = torch.zeros(M * topk, + w2.shape[1], + dtype=torch.bfloat16, + device=a.device) num_experts = w1.shape[0] for i in range(num_experts): mask = (topk_ids == i).view(-1) @@ -341,11 +336,8 @@ def torch_moe2( # a_scale[mask], # w1_scale[i], # torch.bfloat16) - tmp1 = native_w8a8_block_matmul(a[mask], - w1[i], - a_scale[mask], - w1_scale[i], - block_shape, + tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], + w1_scale[i], block_shape, torch.bfloat16) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1]) @@ -355,11 +347,8 @@ def torch_moe2( # b_scale, # w2_scale[i], # torch.bfloat16) - out[mask] = native_w8a8_block_matmul(tmp2, - w2[i], - b_scale, - w2_scale[i], - block_shape, + out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, + w2_scale[i], block_shape, torch.bfloat16) return (out.view(M, -1, w2.shape[1]) * @@ -406,23 +395,21 @@ def test_fused_moe_batched_experts( factor_for_scale = 1e-2 w1_s = torch.rand( - (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale + (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, + device="cuda") * factor_for_scale w2_s = torch.rand( - (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale + (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, + device="cuda") * factor_for_scale else: w1_s = None w2_s = None with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape) - batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape) - # batched_output = batched_moe(a, - # w1.to(torch.bfloat16), - # w2.to(torch.bfloat16), - # topk_weight, topk_ids, - # w1_s, w2_s, False, - # block_shape) + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, + w2_s, use_fp8_w8a8, block_shape) + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, + w2_s, use_fp8_w8a8, block_shape) torch.testing.assert_close(baseline_output, batched_output, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 1f041730815c6..b1289ae0f53cd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -9,47 +9,44 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache, - cdiv) @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr -): + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): offs_k = tl.arange(0, BLOCK_K) if use_w8a16: @@ -313,22 +310,21 @@ def batched_triton_kernel( def invoke_moe_batched_triton_kernel( - A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - block_shape: Optional[list[int]] = None -): + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) @@ -392,8 +388,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): that the PPLX dispatch/combine kernels use. """ - def __init__(self, max_num_tokens: Optional[int], world_size: int, - dp_size: int, rank: int, use_fp8_w8a8: bool = False, + def __init__(self, + max_num_tokens: Optional[int], + world_size: int, + dp_size: int, + rank: int, + use_fp8_w8a8: bool = False, block_shape: Optional[list[int]] = None): super().__init__() self.world_size = world_size @@ -463,13 +463,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() rows = torch.count_nonzero(topks.flatten()) - rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert if self.use_fp8_w8a8: # TODO: use _fp8_quantize - b_a1[idx, :rows, :], tmp_scale = per_token_group_quant_fp8(rhs, block_k) - b_a1_scale[idx, :rows] = tmp_scale # inline? + b_a1[idx, :rows, :], b_a1_scale[ + idx, :rows] = per_token_group_quant_fp8(rhs, block_k) else: b_a1[idx, :rows, :] = rhs @@ -549,7 +548,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): num_dp = self.world_size // self.dp_size max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens - #print(f"WORKSPACE {max_num_tokens} {num_dp}") workspace13 = num_experts * max_num_tokens * num_dp * K workspace2 = max_num_tokens * num_dp * N return (workspace13, workspace2, a.dtype) @@ -607,9 +605,10 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): num = int(expert_num_tokens[expert].item()) tmp = _resize_cache(workspace2, (num, N)) if self.use_fp8_w8a8: - assert False # TBD + assert False # TBD else: - input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + input = hidden_states[expert, :num, :] @ w1[expert].transpose( + 0, 1) self.activation(activation, tmp, input) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) @@ -768,12 +767,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): #assert not self.use_fp8_w8a8 if self.use_fp8_w8a8: per_act_token = False - qintermediate_cache2 = torch.zeros_like(intermediate_cache2, + # TODO: reuse? + qintermediate_cache2 = torch.empty_like(intermediate_cache2, dtype=torch.float8_e4m3fn) block_n = self.block_shape[0] n_tiles = ((N // 2) + block_n - 1) // block_n scale_shape = (E, num_tokens, n_tiles) - a2q_scale = torch.zeros(scale_shape, + a2q_scale = torch.empty(scale_shape, dtype=torch.float32, device=hidden_states.device) for e in range(E): @@ -783,10 +783,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): # intermediate_cache2[e], # a2_scale[e] if a2_scale is not None else None, # per_act_token, self.block_shape) - qintermediate_cache2[e, :num_tokens, :], tmp_scale = per_token_group_quant_fp8( - intermediate_cache2[e, :num_tokens], block_n) - #print(a2q_scale[e, :tmp_scale.shape[0]].shape) - #print(tmp_scale.shape) + qintermediate_cache2[ + e, : + num_tokens, :], tmp_scale = per_token_group_quant_fp8( + intermediate_cache2[e, :num_tokens], block_n) a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale else: qintermediate_cache2 = intermediate_cache2 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0f9058a3feedd..11210aeaebaf5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1240,8 +1240,6 @@ class FusedMoE(torch.nn.Module): if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) - assert topk_ids.dtype == indices_type - return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: