mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 15:37:59 +08:00
Fixes and updates to bench_per_token_quant_fp8 (#25591)
Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
b1f9a1f46a
commit
984bfb4ba7
@ -51,7 +51,7 @@ def calculate_diff(
|
|||||||
):
|
):
|
||||||
"""Calculate the difference between Inductor and CUDA implementations."""
|
"""Calculate the difference between Inductor and CUDA implementations."""
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
|
x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)
|
||||||
|
|
||||||
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
|
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
|
||||||
|
|
||||||
@ -59,23 +59,25 @@ def calculate_diff(
|
|||||||
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
|
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
|
||||||
cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
|
cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
|
||||||
|
|
||||||
out_allclose = lambda o1, o2: torch.allclose(
|
try:
|
||||||
o1.to(torch.float32),
|
torch.testing.assert_close(
|
||||||
o2.to(torch.float32),
|
cuda_out.to(torch.float32),
|
||||||
rtol=1e-3,
|
torch_out.to(torch.float32),
|
||||||
atol=1e-5,
|
rtol=1e-3,
|
||||||
)
|
atol=1e-5,
|
||||||
scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
|
)
|
||||||
|
torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
|
||||||
if (
|
torch.testing.assert_close(
|
||||||
out_allclose(cuda_out, torch_out)
|
cuda_out.to(torch.float32),
|
||||||
and scale_allclose(cuda_scale, torch_scale)
|
torch_eager_out.to(torch.float32),
|
||||||
and out_allclose(cuda_out, torch_eager_out)
|
rtol=1e-3,
|
||||||
and scale_allclose(cuda_scale, torch_eager_scale)
|
atol=1e-5,
|
||||||
):
|
)
|
||||||
|
torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
|
||||||
print("✅ All implementations match")
|
print("✅ All implementations match")
|
||||||
else:
|
except AssertionError as e:
|
||||||
print("❌ Implementations differ")
|
print("❌ Implementations differ")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
configs = []
|
configs = []
|
||||||
@ -91,7 +93,7 @@ def benchmark_quantization(
|
|||||||
):
|
):
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
|
||||||
x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)
|
x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
|
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
|
||||||
@ -157,21 +159,21 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument("-c", "--check", action="store_true")
|
parser.add_argument("-c", "--check", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hidden-sizes",
|
"--hidden-sizes",
|
||||||
type=int,
|
type=int,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
default=None,
|
default=[896, 1024, 2048, 4096, 7168],
|
||||||
help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
|
help="Hidden sizes to benchmark",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch-sizes",
|
"--batch-sizes",
|
||||||
type=int,
|
type=int,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
default=None,
|
default=[1, 16, 128, 512, 1024],
|
||||||
help="Batch sizes to benchmark (default: 1,16,32,64,128)",
|
help="Batch sizes to benchmark",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group-sizes",
|
"--group-sizes",
|
||||||
@ -192,8 +194,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
||||||
|
|
||||||
hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
hidden_sizes = args.hidden_sizes
|
||||||
batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
|
batch_sizes = args.batch_sizes
|
||||||
|
|
||||||
if args.group_sizes is not None:
|
if args.group_sizes is not None:
|
||||||
group_shapes = []
|
group_shapes = []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user