From e2b85cf86a522e734a38b1d0314cfe9625003ef9 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Sun, 16 Jun 2024 23:48:06 -0700
Subject: [PATCH] Fix w8a8 benchmark and add Llama-3-8B (#5562)

---
 .../cutlass_benchmarks/w8a8_benchmarks.py | 21 ++++++++++++-------
 .../cutlass_benchmarks/weight_shapes.py   |  6 ++++++
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index 182105f0b33f..523e970c2c9b 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 
 
 # impl
-def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
+def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
     return torch.mm(a, b)
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
-                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
     # cutlass impl
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     timers = []
 
+    # pytorch impl w. bf16
+    timers.append(
+        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
+                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))
+
     # pytorch impl: bf16 output, without fp8 fast accum
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     # cutlass impl: bf16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.bfloat16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_bf16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
 
     # cutlass impl: fp16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.float16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_fp16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
 
     return timers
diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py
index 7ad4a53d376b..25ec9d602862 100644
--- a/benchmarks/cutlass_benchmarks/weight_shapes.py
+++ b/benchmarks/cutlass_benchmarks/weight_shapes.py
@@ -22,6 +22,12 @@ WEIGHT_SHAPES = {
         ([4096, 22016], 1),
         ([11008, 4096], 0),
     ],
+    "meta-llama/Llama-3-8b": [
+        ([4096, 6144], 1),
+        ([4096, 4096], 0),
+        ([4096, 28672], 1),
+        ([14336, 4096], 0),
+    ],
     "meta-llama/Llama-2-13b-hf": [
         ([5120, 15360], 1),
         ([5120, 5120], 0),
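
Note on the w8a8_benchmarks.py changes: pytorch_i8_impl was renamed to
pytorch_mm_impl because it is now shared by both bench_int8 and bench_fp8 as a
plain bf16 torch.mm baseline that deliberately ignores scale_a/scale_b, giving
the quantized cutlass kernels an unquantized PyTorch reference point. The
cutlass calls also stop moving the scales to the CPU, keeping them on the same
device as the fp8 inputs. As a minimal sketch of the kind of measurement
bench_fn produces (the time_mm helper, shapes, and min_run_time below are
illustrative assumptions, not code from this patch), such a baseline can be
timed directly with torch.utils.benchmark:

    import torch
    import torch.utils.benchmark as TBenchmark

    def time_mm(a: torch.Tensor, b: torch.Tensor,
                label: str) -> TBenchmark.Measurement:
        # Time a plain bf16 matmul; the scales are intentionally unused,
        # matching the "no-scales" baseline added to bench_fp8 here.
        timer = TBenchmark.Timer(
            stmt="torch.mm(a, b)",
            globals={"torch": torch, "a": a, "b": b},
            label=label,
        )
        return timer.blocked_autorange(min_run_time=1)

    a = torch.randn(512, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 6144, device="cuda", dtype=torch.bfloat16)
    print(time_mm(a, b, "pytorch_bf16_bf16_bf16_matmul-no-scales"))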
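
Note on the weight_shapes.py change: the new entry follows the same
([K, N], split-dim) convention as the surrounding models, and the four shapes
fall out of the published Llama-3-8B config (hidden_size 4096,
intermediate_size 14336, 32 attention heads, 8 KV heads). A quick
self-contained check (the variable names below are mine, not from the repo):

    # Derive the Llama-3-8B shapes added above from the model config.
    hidden = 4096                 # hidden_size
    inter = 14336                 # intermediate_size
    heads, kv_heads = 32, 8       # attention heads / KV heads
    head_dim = hidden // heads    # 128

    shapes = [
        [hidden, (heads + 2 * kv_heads) * head_dim],  # fused QKV -> [4096, 6144]
        [hidden, hidden],                             # attn out  -> [4096, 4096]
        [hidden, 2 * inter],                          # gate/up   -> [4096, 28672]
        [inter, hidden],                              # MLP down  -> [14336, 4096]
    ]
    assert shapes == [[4096, 6144], [4096, 4096], [4096, 28672], [14336, 4096]]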