[ Kernel ] Enable Dynamic Per Token fp8 (#6547)

Robert Shaw 2024-07-19 19:08:15 -04:00, committed by GitHub
parent 07eb6f19f3
commit 4cc24f01b1
7 changed files with 67 additions and 38 deletions


@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.769
+  - name: "exact_match,flexible-extract"
+    value: 0.769
+limit: 1000
+num_fewshot: 5
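
The new config pins the GSM8K accuracy of the channelwise FP8 model so the lm-eval-harness CI job can flag regressions. As a rough illustration of how a baseline file like this can be consumed (this is not the actual .buildkite harness script; the results-dict layout and the tolerance are assumptions), a check could look like:

# Illustrative sketch only: compares measured lm-eval metrics against the
# baseline yaml above. The results-dict layout and RTOL are assumptions,
# not taken from the .buildkite harness.
import yaml

RTOL = 0.05  # assumed tolerance


def check_gsm8k_baseline(config_path: str, results: dict) -> None:
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for task in config["tasks"]:
        measured = results["results"][task["name"]]
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[metric["name"]]
            assert abs(got - expected) <= RTOL, (
                f"{task['name']}/{metric['name']}: got {got}, "
                f"expected ~{expected}")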


@@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml


@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                    device="cuda") + 1e-6  # avoid nans
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               use_per_token_if_dynamic=True)
     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
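
For context, dynamic per-token FP8 quantization derives one scale per row of the activation matrix from that row's absolute maximum, rather than a single scale for the whole tensor. A minimal sketch of the reference computation the test compares against (an approximation for illustration, not the actual ref_dynamic_per_token_quant helper from the test utilities):

import torch


def per_token_fp8_quant_sketch(x: torch.Tensor,
                               fp8_dtype=torch.float8_e4m3fn):
    # One scale per token (row): row-wise amax over the hidden dimension.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = amax / fp8_max                      # shape: (num_tokens, 1)
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q, scale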


@@ -300,6 +300,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -315,6 +316,8 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         batch_dim_padding: If specified, pad the first dimension
             of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -328,24 +331,21 @@
     else:
         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
-        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        if use_per_token_if_dynamic:
+            scale = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
     return output, scale


-def dynamic_per_token_scaled_fp8_quant(
-        input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    scales = torch.empty((input.numel() // input.shape[-1], 1),
-                         device=input.device,
-                         dtype=torch.float32)
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
-    return output, scales


 # int8
 def scaled_int8_quant(
     input: torch.Tensor,
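
With the flag in place, callers pick the activation-scale granularity at the single scaled_fp8_quant entry point instead of calling a separate per-token op. A usage sketch, assuming a CUDA build of vLLM with the custom ops compiled:

import torch

from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor quantization (the previous default behavior).
q_tensor, scale_tensor = ops.scaled_fp8_quant(x)
assert scale_tensor.numel() == 1

# Dynamic per-token quantization: one scale per row of x.
q_token, scale_token = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert scale_token.shape == (x.shape[0], 1)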


@@ -103,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=True)


@@ -214,7 +214,8 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False)


 class Fp8MoEMethod(FusedMoEMethodBase):


@@ -107,31 +107,43 @@ def apply_fp8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
     #   If dynamic, layer.input_scale is None and x_scale computed from x.
     #   If static, layer.input_scale is scalar and x_scale is input_scale.

+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)

         # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
+        return ops.cutlass_scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)

+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
     else:
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
-        qinput, x_scale = ops.scaled_fp8_quant(input,
-                                               input_scale,
-                                               batch_dim_padding=17)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            batch_dim_padding=17,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)

-        if weight_scale.numel() == 1:
+        per_tensor_weights = (weight_scale.numel() == 1)
+        per_tensor_activations = (x_scale.numel() == 1)

+        if per_tensor_weights and per_tensor_activations:
             # Fused GEMM_DQ
             output, _ = torch._scaled_mm(qinput,
                                          weight,
@@ -139,9 +151,11 @@ def apply_fp8_linear(
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          bias=bias)
+            return torch.narrow(output, 0, 0, input.shape[0])

         else:
-            # Fallback for channelwise case, where the weight scales are
-            # applied separately.
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm

             # Symmetric quantized GEMM by definition computes the following:
             #   C = (s_x * X) (s_w * W) + bias
@@ -155,21 +169,21 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.

-            # This computes C = sx * (X * W).
+            # GEMM
+            # This computes C = (X * W).
+            # Output in fp32 to allow subsequent ops to happen in-place
             output, _ = torch._scaled_mm(qinput,
                                          weight,
-                                         out_dtype=torch.float32,
-                                         scale_a=x_scale)
+                                         out_dtype=torch.float32)
+            # Unpad (undo batch_dim_padding)
+            output = torch.narrow(output, 0, 0, input.shape[0])

-            # C = sw * sx * (X * W)
-            output = output * weight_scale.t()
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
             if bias is not None:
-                # C = sw * sx * (X * W) + bias
                 output = output + bias
-            output = output.to(dtype=input.dtype)

-    return torch.narrow(output, 0, 0, input.shape[0])
+            return output.to(dtype=input.dtype)


 def apply_int8_linear(
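
The unfused fallback relies on the identity (s_x * X_q) @ (W_q * s_w^T) + bias == s_x * s_w^T * (X_q @ W_q) + bias: because s_x scales rows of X and s_w scales columns of W, both scales can be pulled outside the matmul and applied after a plain GEMM, which is what the channelwise path above does. A small sketch of that algebra with ordinary float tensors standing in for torch._scaled_mm (shapes only mirror the diff; the values are made up):

import torch

# x_scale is per-token (M, 1); weight_scale is per-output-channel (N, 1),
# matching the shapes implied by the fallback path above.
M, K, N = 4, 8, 6
x_q = torch.randn(M, K)                 # stand-in for quantized activations
w_q = torch.randn(K, N)                 # stand-in for quantized weights
x_scale = torch.rand(M, 1) + 0.1        # per-token activation scales
weight_scale = torch.rand(N, 1) + 0.1   # per-channel weight scales
bias = torch.randn(N)

# Fused form: dequantize operands first, then GEMM.
fused = (x_scale * x_q) @ (w_q * weight_scale.t()) + bias

# Unfused form (what the fallback computes): GEMM first, scales applied after.
unfused = (x_q @ w_q) * x_scale * weight_scale.t() + bias

assert torch.allclose(fused, unfused, atol=1e-5)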