diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index d5fb35324de5f..a3102d0cb8806 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -276,7 +276,7 @@ def batched_moe( rank=0, qtype=qtype, block_shape=block_shape, - per_act_token=False), + per_act_token=per_act_token), BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1, @@ -327,22 +327,13 @@ def torch_moe2( tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) else: - #tmp1 = ops.cutlass_scaled_mm(a[mask], - # w1[i].transpose(0, 1), - # a_scale[mask], - # w1_scale[i], - # torch.bfloat16) tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, torch.bfloat16) + tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1]) - # out[mask] = ops.cutlass_scaled_mm(tmp2, - # w2[i].transpose(0, 1), - # b_scale, - # w2_scale[i], - # torch.bfloat16) out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, w2_scale[i], block_shape, torch.bfloat16) @@ -403,10 +394,10 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, - w2_s, use_fp8_w8a8, block_shape) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, qtype, block_shape) + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, + w2_s, use_fp8_w8a8, block_shape) torch.testing.assert_close(baseline_output, batched_output, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 3cae2b0ecfdec..44fb82a5f427a 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -33,7 +33,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, get_default_config) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.platforms import current_platform +from vllm.utils import round_up PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), (222, 2048, 1024)] @@ -280,6 +283,70 @@ def batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) +def native_w8a8_block_matmul(A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size, + output_dtype=torch.bfloat16): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. 
+ """ + A = A.to(torch.float32) + B = B.to(torch.float32).contiguous() + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], ( + f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}") + assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}" + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + # Note: same as torch_moe but with fused_topk factored out. def torch_moe2( a: torch.Tensor, @@ -287,17 +354,44 @@ def torch_moe2( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + block_shape: Optional[list[int]] = None, ) -> torch.Tensor: M, K = a.shape topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + if use_fp8_w8a8: + a, a_scale = per_token_group_quant_fp8(a, block_shape[1]) + else: + a_scale = None + + out = torch.zeros(M * topk, + w2.shape[1], + dtype=torch.bfloat16, + device=a.device) num_experts = w1.shape[0] for i in range(num_experts): mask = (topk_ids == i).view(-1) if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + if not use_fp8_w8a8: + tmp1 = a[mask] @ w1[i].transpose(0, 1) + tmp2 = SiluAndMul()(tmp1) + out[mask] = tmp2 @ w2[i].transpose(0, 1) + else: + tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], + w1_scale[i], block_shape, + torch.bfloat16) + + tmp2 = SiluAndMul()(tmp1) + tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1]) + + out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, + w2_scale[i], block_shape, + torch.bfloat16) return (out.view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @@ -502,6 +596,10 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + qtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, use_compile: bool = True, use_cudagraphs: bool = True, ) -> torch.Tensor: @@ -511,9 +609,20 @@ def pplx_moe( device = torch.device("cuda", rank) hidden_dim = a.shape[1] num_experts = w1.shape[0] - block_size = 128 + block_size = block_shape[1] if block_shape is not None else 128 topk = topk_ids.shape[1] - max_num_tokens = 
rank_chunk(a.shape[0], 0, world_size)
+    if block_shape:
+        max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0])
+    else:
+        max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+
+    if qtype is not None:
+        a_dtype = qtype
+        scale_bytes = ((hidden_dim + block_size - 1) // block_size *
+                       torch.float32.itemsize)
+    else:
+        a_dtype = a.dtype
+        scale_bytes = 0
 
     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,
@@ -523,10 +632,8 @@
         world_size=world_size,
         dp_size=dp_size,
         hidden_dim=hidden_dim,
-        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
-        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
-                                ((hidden_dim + block_size - 1) // block_size *
-                                 torch.float32.itemsize)),
+        hidden_dim_bytes=hidden_dim * a_dtype.itemsize,
+        hidden_dim_scale_bytes=scale_bytes,
     )
 
     topk_ids = topk_ids.to(dtype=torch.uint32)
@@ -537,11 +644,15 @@
         world_size,
         rank,
         dp_size,
+        quant_dtype=qtype,
+        block_shape=block_shape,
     )
 
-    experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
+    experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                    world_size=world_size,
-                                   dp_size=dp_size)
+                                   dp_size=dp_size,
+                                   use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+                                   block_shape=block_shape)
 
     fused_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -557,7 +668,14 @@
     w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
     w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
 
-    if use_compile:
+    if w1_scale is not None:
+        w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
+        w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
+    else:
+        w1_scale_chunk = None
+        w2_scale_chunk = None
+
+    if False and use_compile:
         _fused_experts = torch.compile(fused_experts,
                                        backend='inductor',
                                        fullgraph=True)
@@ -569,9 +687,11 @@
                          w2_chunk,
                          chunk_topk_weight,
                          chunk_topk_ids,
+                         w1_scale=w1_scale_chunk,
+                         w2_scale=w2_scale_chunk,
                          global_num_experts=num_experts)
 
-    if use_cudagraphs:
+    if False and use_cudagraphs:
         out.fill_(0)
         stream = torch.cuda.Stream()
         graph = torch.cuda.CUDAGraph()
@@ -581,6 +701,8 @@
                                  w2_chunk,
                                  chunk_topk_weight,
                                  chunk_topk_ids,
+                                 w1_scale=w1_scale_chunk,
+                                 w2_scale=w2_scale_chunk,
                                  global_num_experts=num_experts)
 
         torch.cuda.synchronize()
@@ -643,6 +765,10 @@ def _pplx_moe(
     w2: torch.Tensor,
     score: torch.Tensor,
     topk: int,
+    w1_s: Optional[torch.Tensor] = None,
+    w2_s: Optional[torch.Tensor] = None,
+    qtype: Optional[torch.dtype] = None,
+    block_shape: Optional[list[int]] = None,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
@@ -654,11 +780,20 @@
     moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
 
+    use_fp8_w8a8 = qtype == torch.float8_e4m3fn
+
+    device = torch.device("cuda", pgi.rank)
+    a = a.to(device)
+    w1 = w1.to(device)
+    w2 = w2.to(device)
+    w1_s = w1_s.to(device) if w1_s is not None else None
+    w2_s = w2_s.to(device) if w2_s is not None else None
+
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
         pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
-                               topk_weight, topk_ids)
+                               topk_weight, topk_ids, w1_s, w2_s, qtype, block_shape)
 
         # TODO (bnell): fix + re-enable
         #batched_output = _batched_moe(pgi, 
dp_size, a, w1, w2, topk_weight, # topk_ids) @@ -675,7 +810,7 @@ def _pplx_moe( @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_moe( @@ -688,9 +823,40 @@ def test_pplx_moe( current_platform.seed_everything(7) m, n, k = mnk world_size, dp_size = world_dp_size - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) + use_fp8_w8a8 = dtype == torch.float8_e4m3fn + + if use_fp8_w8a8: + block_shape = [128, 128] + quant_type = torch.float8_e4m3fn + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (2 * n + block_n - 1) // block_n + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + k_tiles_w2 = (n + block_k - 1) // block_k + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype) + w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype) + + factor_for_scale = 1e-2 + w1_s = torch.rand( + (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, + device="cuda") * factor_for_scale + w2_s = torch.rand( + (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, + device="cuda") * factor_for_scale + else: + block_shape = None + quant_type = None + w1_s = None + w2_s = None + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_type, block_shape) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7d42894ffe1a1..f0fea396cc8de 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -457,6 +457,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dtype=torch.float32, device=a1.device) else: + assert a1_scale is None b_a1_scale = None first_expert = num_local_experts * self.rank @@ -782,8 +783,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): (E, num_tokens, N // 2)) intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) - assert not self.use_fp8_w8a8 or a1q_scale is not None - # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, B=w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index feefd9522e730..ff6ecffc7663a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,6 +8,9 @@ from typing import Callable, Optional import torch import torch.nn.functional as F +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -238,6 +241,18 @@ class FusedMoeWeightScaleSupported(Enum): 
BLOCK = "block" +def get_quant_config_input_activations( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get( + "input_activations") + else: + return None + + class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod @@ -262,12 +277,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize, # For blocked per token: set to # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( - (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else ( + ((moe.hidden_dim + moe.block_size - 1) // moe.block_size) * torch.float32.itemsize)), ) @@ -285,7 +300,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): rank=all2all_manager.rank, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, - quant_dtype=moe.in_dtype, + quant_dtype=moe.quant_dtype, ) if prepare_finalize is not None: @@ -774,6 +789,17 @@ class FusedMoE(torch.nn.Module): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + quant_dtype = vllm_config.model_config.dtype + if quant_config is not None: + input_activations = get_quant_config_input_activations( + quant_config) + if (input_activations is not None + and input_activations.num_bits == 8): + if input_activations.type == QuantizationType.FLOAT: + quant_dtype = torch.float8_e4m3fn + elif input_activations.type == QuantizationType.INT: + quant_dtype = torch.int8 + moe = MoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, @@ -781,6 +807,7 @@ class FusedMoE(torch.nn.Module): num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=vllm_config.model_config.dtype, + quant_dtype=quant_dtype, max_num_tokens=MOE_DP_CHUNK_SIZE, ) self.moe_config = moe @@ -822,12 +849,12 @@ class FusedMoE(torch.nn.Module): if self.moe_parallel_config.use_pplx_kernels: self.batched_hidden_states = torch.zeros( (MOE_DP_CHUNK_SIZE, self.hidden_size), - dtype=moe.in_dtype, + dtype=vllm_config.model_config.dtype, device=torch.cuda.current_device()) self.batched_router_logits = torch.zeros( (MOE_DP_CHUNK_SIZE, self.global_num_experts), - dtype=moe.in_dtype, + dtype=vllm_config.model_config.dtype, device=torch.cuda.current_device()) @property
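A note on the block-wise W8A8 reference path exercised by the tests above: `native_w8a8_block_matmul` accumulates the product tile by tile, scaling each partial result by the per-token-group activation scale (`As`, one scale per `block_k` slice of each row) times the per-tile weight scale (`Bs`, indexed `[n_tile, k_tile]`). The sketch below is illustrative only and not part of the patch; it uses made-up sizes, keeps everything in plain float torch, and skips the actual fp8 quantization so it runs anywhere.

# Illustrative sketch (not part of the diff): the tiling and scale layout that
# native_w8a8_block_matmul expects, with made-up shapes and random scales.
import torch

block_n, block_k = 128, 128   # same role as block_shape = [128, 128] in the tests
M, K, N = 4, 256, 384         # tokens, hidden size, output size (made-up)

A = torch.randn(M, K) / 10    # activations; float stand-in, the real path quantizes to fp8
B = torch.randn(N, K) / 10    # weight stored [N, K], as w1[i] / w2[i] are

# One activation scale per token per block_k group (the layout that
# per_token_group_quant_fp8 returns alongside the quantized tensor).
As = torch.rand(M, (K + block_k - 1) // block_k) * 1e-2
# One weight scale per (block_n x block_k) tile, indexed [n_tile, k_tile].
Bs = torch.rand((N + block_n - 1) // block_n,
                (K + block_k - 1) // block_k) * 1e-2

# Accumulate tile by tile, applying As * Bs to each partial product.
C = torch.zeros(M, N)
for j in range(Bs.shape[0]):          # over n tiles
    for i in range(Bs.shape[1]):      # over k tiles
        a = A[:, i * block_k:(i + 1) * block_k]
        b = B[j * block_n:(j + 1) * block_n, i * block_k:(i + 1) * block_k]
        scale = As[:, i:i + 1] * Bs[j, i]              # (M, 1) times scalar
        C[:, j * block_n:(j + 1) * block_n] += (a @ b.t()) * scale

Because the weight scales are indexed by tile over the `[N, K]` layout, the tests above allocate them as `(e, n_tiles_w1, k_tiles_w1)` and `(e, n_tiles_w2, k_tiles_w2)` per expert.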