# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
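"""Micro-benchmark for the fused SiLU-and-mul + per-token-group FP8
column-major quantization kernel.

Compares silu_mul_per_token_group_quant_fp8_colmajor against an unfused
reference that runs silu_and_mul followed by the standalone Triton
_per_token_group_quant_fp8_colmajor kernel. Timings are collected with CUDA
graph capture over a pool of pre-allocated argument tensors.
"""
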
from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn


def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note: The timings reported below are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single-invocation "
        "timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()


class ImplType(Enum):
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")


@dataclass
class BenchmarkTensors:
    """Pre-allocated input/output tensors for one (num_tokens, N) config."""

    input: torch.Tensor
    output: torch.Tensor

    # Reference act output tensor
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )

        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")


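# Shape sketch (derived from the code below): for input (T, N) = (128, 4096),
# e.g. concatenated gate/up projections, both implementations produce an FP8
# activation of shape (T, N // 2) = (128, 2048) and per-token-group scales of
# shape (T, N // 2 // GROUP_SIZE) = (128, 16), with the scales stored
# column-major (hence the _colmajor suffix).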
def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference Triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Unfused reference: SiLU-and-mul activation kernel followed by the
    # standalone per-token-group quant kernel.
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)


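# Note on the pooling below: each BenchmarkTensors entry contributes one value
# per kwarg to an ArgPool, and CudaGraphBenchParams(arg_pool_size) is passed to
# Bench (both from .utils). Judging from this usage and the note printed by
# print_timers, Bench appears to capture arg_pool_size consecutive invocations
# in a single CUDA graph, with each invocation drawing a different tensor set
# from the pool.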
def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def test_correctness(T: int, N: int):
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)

    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)


def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]

        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)
        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)

        # Per-shape comparison.
        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )

    # Comparison over the full sweep.
    print_timers(timers, cuda_graph_nops=arg_pool_size)

    return timers


if __name__ == "__main__":
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]

    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)
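
# Example of a quick single-shape check/benchmark using the helpers above
# (a hypothetical usage sketch, not part of the original sweep; it requires a
# CUDA device, and the script must be run as a module so the relative
# `from .utils import ...` resolves, e.g. from vLLM's benchmarks/kernels
# package):
#
#   test_correctness(T=256, N=4096)
#   pool = [BenchmarkTensors.make(256, 4096) for _ in range(8)]
#   print_timers(
#       [bench_impl(pool, impl) for impl in ImplType],
#       cuda_graph_nops=8,
#   )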