From 25ed6738d411a5da4f0f8daee7365443baa97ef5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 21:17:52 +0000 Subject: [PATCH] wip Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 118 ++++++++++++++++-- .../layers/fused_moe/fused_batched_moe.py | 84 +++++++++++-- 2 files changed, 182 insertions(+), 20 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 3f38b9fbcb3ca..8c143c808cf86 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import pytest import torch import triton.language as tl +from typing import Optional from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( invoke_moe_batched_triton_kernel) @@ -57,20 +58,94 @@ class BatchedMMTensors: return BatchedMMTensors(A, B, C, num_expert_tokens) -def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - num_expert_tokens: torch.Tensor) -> torch.Tensor: +def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, + As: torch.Tensor, Bs: torch.Tensor, + block_size, + output_dtype = torch.bfloat16): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. + """ + A = A.to(torch.float32) + B = B.to(torch.float32).contiguous() + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], ( + f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}") + assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}" + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def ref_impl( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + num_expert_tokens: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + block_shape: Optional[list[int]], +) -> torch.Tensor: num_expert_tokens_cpu = num_expert_tokens.clone() num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") num_experts = num_expert_tokens.size(0) for e in range(num_experts): num_tokens = num_expert_tokens_cpu[e] - C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + if A.dtype == torch.torch.float8_e4m3fn: + C[e, :, :] = native_w8a8_block_matmul(A[e, :, :], + B[e].transpose(0, 1), + A_scale, + B_scale, + [1,1])#block_shape) + else: + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) return C - @pytest.mark.parametrize("num_experts", [16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) @@ -94,6 +169,20 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32 }[test_output.dtype] + + use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn + block_shape = [16, 16, 32] # 16 for k if not fp8 + + print(f"tensors.A {tensors.A.shape}") + print(f"tensors.B {tensors.B.shape}") + + if use_fp8_w8a8: + A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device) + B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device) + else: + A_scale = None + B_scale = None + invoke_moe_batched_triton_kernel( tensors.A, tensors.B, @@ -101,21 +190,26 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, tensors.num_expert_tokens, compute_tl_dtype, # Quantization data - None, - None, + A_scale, + B_scale, None, # Quantization schemes - dtype == torch.torch.float8_e4m3fn, + use_fp8_w8a8, False, False, config={ - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16 + "BLOCK_SIZE_M": block_shape[0], + "BLOCK_SIZE_N": block_shape[1], + "BLOCK_SIZE_K": block_shape[2], }) - ref_output = ref_impl(tensors.A, tensors.B, ref_output, - tensors.num_expert_tokens) + ref_output = ref_impl(tensors.A, + tensors.B, + ref_output, + tensors.num_expert_tokens, + A_scale, + B_scale, + block_shape[-2:]) rtol, atol = { torch.torch.float8_e4m3fn: (6e-2, 6e-2), diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 50c9e21af6a04..8695bd357d315 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -316,8 +316,8 @@ def invoke_moe_batched_triton_kernel( expert_num_tokens: torch.Tensor, # [E] compute_type: tl.dtype, # Quantization data - A_scale: torch.Tensor, - B_scale: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], B_zp: torch.Tensor, # Quantization schemes use_fp8_w8a8: bool, @@ -500,15 +500,16 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): block_m: Optional[int] = None, ): super().__init__() - assert block_shape is None + #assert block_shape is None assert block_m is None - assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size + self.use_fp8_w8a8 = use_fp8_w8a8 + self.block_shape = block_shape def workspace_shapes( self, @@ -528,6 +529,66 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2 = max_num_tokens * num_dp * N return (workspace13, workspace2, a.dtype) + def native_w8a8_block_matmul(A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. + """ + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert self.block_shape is not None and len(self.block_shape) == 2 + block_n, block_k = self.block_shape[0], self.block_shape[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + def apply( self, hidden_states: torch.Tensor, @@ -580,9 +641,15 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): else: num = int(expert_num_tokens[expert].item()) tmp = _resize_cache(workspace2, (num, N)) - input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) - self.activation(activation, tmp, input) - out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + if self.use_fp8_w8a8: + assert False # TBD + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + self.activation(activation, tmp, input) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + else: + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + self.activation(activation, tmp, input) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out @@ -732,7 +799,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache1.view(-1, N)) #qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale + # TODO (varun) : support w8a8 #assert not self.use_fp8_w8a8 if self.use_fp8_w8a8: @@ -753,6 +820,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): per_act_token, self.block_shape) else: qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2,