diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index b1dda531e6b30..f225d0ecc313e 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
 
 @dataclass
 class BatchedMMConfig:
-    dtype: torch.dtype
+    in_dtype: torch.dtype
+    out_dtype: torch.dtype
     num_experts: int
     max_tokens_per_expert: int
     K: int
@@ -29,26 +30,25 @@ class BatchedMMTensors:
 
     @staticmethod
     def make_tensors(config: BatchedMMConfig):
-        if config.dtype == torch.torch.float8_e4m3fn:
-            config_dtype = torch.bfloat16
+        if config.in_dtype == torch.torch.float8_e4m3fn:
+            config_in_dtype = torch.bfloat16
         else:
-            config_dtype = config.dtype
+            config_in_dtype = config.in_dtype
 
         A = torch.randn(
             (config.num_experts, config.max_tokens_per_expert, config.K),
             device="cuda",
-            dtype=config_dtype) / 10
+            dtype=config_in_dtype) / 10
         B = torch.randn((config.num_experts, config.N, config.K),
                         device="cuda",
-                        dtype=config_dtype)
+                        dtype=config_in_dtype)
         C = torch.zeros(
             (config.num_experts, config.max_tokens_per_expert, config.N),
             device="cuda",
-            dtype=config_dtype)
+            dtype=config.out_dtype)
 
-        A = A.to(config.dtype)
-        B = B.to(config.dtype)
-        C = C.to(config.dtype)
+        A = A.to(config.in_dtype)
+        B = B.to(config.in_dtype)
 
         num_expert_tokens = torch.randint(low=0,
                                           high=config.max_tokens_per_expert,
@@ -136,11 +136,19 @@ def ref_impl(
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
         if A.dtype == torch.torch.float8_e4m3fn:
-            tmp = native_w8a8_block_matmul(A[e, :, :],
-                                           B[e].transpose(0, 1),
-                                           A_scale,
-                                           B_scale,
-                                           [1,1])#block_shape)
+            if False:
+                tmp = native_w8a8_block_matmul(A[e, :, :],
+                                               B[e].transpose(0, 1),
+                                               A_scale,
+                                               B_scale,
+                                               [1,1])#block_shape)
+            else:
+                import vllm._custom_ops as ops
+                tmp = ops.cutlass_scaled_mm(A[e, :, :],
+                                            B[e].transpose(0, 1),
+                                            A_scale,
+                                            B_scale,
+                                            C.dtype)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         else:
             C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
@@ -159,14 +167,21 @@ def ref_impl(
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype):
 
-    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
+    if dtype == torch.torch.float8_e4m3fn:
+        in_dtype = dtype
+        out_dtype = torch.bfloat16
+    else:
+        in_dtype = dtype
+        out_dtype = dtype
+
+    config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)
 
     test_output = tensors.C
     ref_output = test_output.clone()
+    ref_output2 = test_output.clone()
 
     compute_tl_dtype = {
-        torch.torch.float8_e4m3fn: tl.bfloat16,
         torch.float16: tl.float16,
         torch.bfloat16: tl.bfloat16,
         torch.float32: tl.float32
@@ -175,12 +190,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
     block_shape = [16, 16, 32] # 16 for k if not fp8
 
-    print(f"tensors.A {tensors.A.shape}")
-    print(f"tensors.B {tensors.B.shape}")
+    #print(f"tensors.A {tensors.A.shape}")
+    #print(f"tensors.B {tensors.B.shape}")
 
     if use_fp8_w8a8:
-        A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
-        B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
+        #A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
+        #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
+        A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
+        B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
     else:
         A_scale = None
         B_scale = None
@@ -205,19 +222,29 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             "BLOCK_SIZE_K": block_shape[2],
         })
 
-    ref_output = ref_impl(tensors.A,
-                          tensors.B,
+    ref_output = ref_output.to(dtype=out_dtype)
+    ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
+                          tensors.B.to(dtype=out_dtype),
                           ref_output,
                           tensors.num_expert_tokens,
                           A_scale,
                           B_scale,
                           block_shape[-2:])
 
+    ref_output2 = ref_impl(tensors.A,
+                           tensors.B,
+                           ref_output2,
+                           tensors.num_expert_tokens,
+                           A_scale,
+                           B_scale,
+                           block_shape[-2:])
+
     rtol, atol = {
-        torch.torch.float8_e4m3fn: (6e-2, 6e-2),
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]
 
-    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
+    if not use_fp8_w8a8:
+        torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)