Fixes and updates to bench_per_token_quant_fp8 (#25591)

Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Michael Goin 2025-09-24 11:30:15 -04:00 committed by yewentao256
parent b1f9a1f46a
commit 984bfb4ba7

View File

@ -51,7 +51,7 @@ def calculate_diff(
): ):
"""Calculate the difference between Inductor and CUDA implementations.""" """Calculate the difference between Inductor and CUDA implementations."""
device = torch.device("cuda") device = torch.device("cuda")
x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device) x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False) quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
@ -59,23 +59,25 @@ def calculate_diff(
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x) torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
cuda_out, cuda_scale = quant_fp8.forward_cuda(x) cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
out_allclose = lambda o1, o2: torch.allclose( try:
o1.to(torch.float32), torch.testing.assert_close(
o2.to(torch.float32), cuda_out.to(torch.float32),
rtol=1e-3, torch_out.to(torch.float32),
atol=1e-5, rtol=1e-3,
) atol=1e-5,
scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5) )
torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
if ( torch.testing.assert_close(
out_allclose(cuda_out, torch_out) cuda_out.to(torch.float32),
and scale_allclose(cuda_scale, torch_scale) torch_eager_out.to(torch.float32),
and out_allclose(cuda_out, torch_eager_out) rtol=1e-3,
and scale_allclose(cuda_scale, torch_eager_scale) atol=1e-5,
): )
torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
print("✅ All implementations match") print("✅ All implementations match")
else: except AssertionError as e:
print("❌ Implementations differ") print("❌ Implementations differ")
print(e)
configs = [] configs = []
@ -91,7 +93,7 @@ def benchmark_quantization(
): ):
device = torch.device("cuda") device = torch.device("cuda")
x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype) x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major) quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
@ -157,21 +159,21 @@ if __name__ == "__main__":
) )
parser.add_argument("-c", "--check", action="store_true") parser.add_argument("-c", "--check", action="store_true")
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
) )
parser.add_argument( parser.add_argument(
"--hidden-sizes", "--hidden-sizes",
type=int, type=int,
nargs="+", nargs="+",
default=None, default=[896, 1024, 2048, 4096, 7168],
help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)", help="Hidden sizes to benchmark",
) )
parser.add_argument( parser.add_argument(
"--batch-sizes", "--batch-sizes",
type=int, type=int,
nargs="+", nargs="+",
default=None, default=[1, 16, 128, 512, 1024],
help="Batch sizes to benchmark (default: 1,16,32,64,128)", help="Batch sizes to benchmark",
) )
parser.add_argument( parser.add_argument(
"--group-sizes", "--group-sizes",
@ -192,8 +194,8 @@ if __name__ == "__main__":
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] hidden_sizes = args.hidden_sizes
batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128] batch_sizes = args.batch_sizes
if args.group_sizes is not None: if args.group_sizes is not None:
group_shapes = [] group_shapes = []