From 984bfb4ba7ede98006cce9c67f86d1606c33ae82 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Wed, 24 Sep 2025 11:30:15 -0400
Subject: [PATCH] Fixes and updates to bench_per_token_quant_fp8 (#25591)

Signed-off-by: Michael Goin
Signed-off-by: yewentao256
---
 .../kernels/bench_per_token_quant_fp8.py | 50 ++++++++++---------
 1 file changed, 26 insertions(+), 24 deletions(-)

diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py
index 9170361e974b6..e08e5680c191e 100644
--- a/benchmarks/kernels/bench_per_token_quant_fp8.py
+++ b/benchmarks/kernels/bench_per_token_quant_fp8.py
@@ -51,7 +51,7 @@ def calculate_diff(
 ):
     """Calculate the difference between Inductor and CUDA implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
+    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)
 
     quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
 
@@ -59,23 +59,25 @@
     torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
     cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
 
-    out_allclose = lambda o1, o2: torch.allclose(
-        o1.to(torch.float32),
-        o2.to(torch.float32),
-        rtol=1e-3,
-        atol=1e-5,
-    )
-    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
-
-    if (
-        out_allclose(cuda_out, torch_out)
-        and scale_allclose(cuda_scale, torch_scale)
-        and out_allclose(cuda_out, torch_eager_out)
-        and scale_allclose(cuda_scale, torch_eager_scale)
-    ):
+    try:
+        torch.testing.assert_close(
+            cuda_out.to(torch.float32),
+            torch_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
+        torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
+        torch.testing.assert_close(
+            cuda_out.to(torch.float32),
+            torch_eager_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
+        torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
         print("✅ All implementations match")
-    else:
+    except AssertionError as e:
         print("❌ Implementations differ")
+        print(e)
 
 
 configs = []
@@ -91,7 +93,7 @@ def benchmark_quantization(
 ):
     device = torch.device("cuda")
 
-    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)
     quantiles = [0.5, 0.2, 0.8]
     quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
 
@@ -157,21 +159,21 @@ if __name__ == "__main__":
     )
     parser.add_argument("-c", "--check", action="store_true")
     parser.add_argument(
-        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
     )
     parser.add_argument(
         "--hidden-sizes",
         type=int,
         nargs="+",
-        default=None,
-        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+        default=[896, 1024, 2048, 4096, 7168],
+        help="Hidden sizes to benchmark",
     )
     parser.add_argument(
         "--batch-sizes",
         type=int,
         nargs="+",
-        default=None,
-        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+        default=[1, 16, 128, 512, 1024],
+        help="Batch sizes to benchmark",
     )
     parser.add_argument(
         "--group-sizes",
@@ -192,8 +194,8 @@
 
     dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
 
-    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+    hidden_sizes = args.hidden_sizes
+    batch_sizes = args.batch_sizes
 
     if args.group_sizes is not None:
         group_shapes = []
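
Reviewer note (not part of the patch): the input-shape fix in calculate_diff
and benchmark_quantization is the substantive change. Per-token FP8
quantization treats each row of a [num_tokens, hidden_size] activation as one
token and produces one scale per token, so the benchmark input belongs at
(batch_size, hidden_size); the old (batch_size * hidden_size, 4096) tensor
benchmarked a much larger, unrelated shape. Below is a minimal hand-rolled
sketch of the per-token scheme for illustration only; it is not vLLM's
QuantFP8 kernel, which may differ in clamping, dtype handling, and scale
layout.

import torch

def per_token_quant_fp8_ref(x: torch.Tensor):
    # Illustrative reference implementation, not vLLM's QuantFP8.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # One scale per token (row): shape [num_tokens, 1].
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale

batch_size, hidden_size = 16, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16)
x_q, scale = per_token_quant_fp8_ref(x)
assert x_q.shape == (batch_size, hidden_size)
assert scale.shape == (batch_size, 1)

The torch.rand → torch.randn change is related: randn yields zero-mean values
of both signs, a closer match to real activations than uniform [0, 1) samples.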
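
Reviewer note (not part of the patch): the switch from torch.allclose lambdas
to torch.testing.assert_close is mainly about diagnostics. assert_close raises
an AssertionError whose message reports the mismatched-element count and the
greatest absolute and relative differences, so a failing --check run now says
what diverged instead of printing a bare ❌. A minimal sketch of the pattern,
independent of the benchmark:

import torch

a = torch.randn(4, 8)
b = a.clone()
b[0, 0] += 1.0  # inject a deliberate mismatch

try:
    # Raises AssertionError with a detailed mismatch report on failure,
    # unlike torch.allclose, which only returns a bool.
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-5)
    print("✅ tensors match")
except AssertionError as e:
    print("❌ tensors differ")
    print(e)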