mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 04:49:37 +08:00
[Kernel] Switch fp8 layers to use the CUTLASS kernels (#5183)
Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8 see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and #5144 for comparisons across different GEMM sizes.
This commit is contained in:
parent
388596c914
commit
8d75fe48ca
@ -179,7 +179,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
|
||||
# cutlass
|
||||
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
|
||||
a_scales: torch.Tensor, b_scales: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype]) -> torch.Tensor:
|
||||
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
@ -188,7 +188,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
|
||||
n = b.shape[1]
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
|
||||
vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
|
||||
vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -17,6 +17,24 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
version = torch.version.cuda
|
||||
version = version[0] * 10 + version[1]
|
||||
|
||||
# CUTLASS FP8 kernels need at least
|
||||
# CUDA 12.0 on SM90 systems (Hopper)
|
||||
# CUDA 12.4 on SM89 systems (Lovelace)
|
||||
gpu_is_supported = False
|
||||
if capability >= 900:
|
||||
gpu_is_supported = version > 120
|
||||
elif capability >= 890:
|
||||
gpu_is_supported = version > 124
|
||||
|
||||
return gpu_is_supported
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
@ -92,6 +110,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
def _create_scale_param(
|
||||
self,
|
||||
@ -233,25 +252,40 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.act_scale is None and x_scale computed from x.
|
||||
# If static, layer.act_scale is scalar and x_scale set to act_scale.
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x,
|
||||
layer.act_scale,
|
||||
batch_dim_padding=17)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
if bias is None and self.cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm_dq(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
)
|
||||
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x,
|
||||
layer.act_scale,
|
||||
batch_dim_padding=17)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return torch.narrow(output, 0, 0, x.shape[0])
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user