mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:55:01 +08:00
[ Kernel ] Enable Dynamic Per Token fp8 (#6547)
This commit is contained in:
parent
07eb6f19f3
commit
4cc24f01b1
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.769
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.769
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
||||
@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
|
||||
@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
|
||||
device="cuda") + 1e-6 # avoid nans
|
||||
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
|
||||
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
|
||||
ops_out, ops_scales = ops.scaled_fp8_quant(x,
|
||||
use_per_token_if_dynamic=True)
|
||||
|
||||
assert torch.allclose(ref_scales, ops_scales)
|
||||
assert torch.allclose(ref_out.to(dtype=torch.float32),
|
||||
|
||||
@ -300,6 +300,7 @@ def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
batch_dim_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 and return quantized tensor and scale.
|
||||
@ -315,6 +316,8 @@ def scaled_fp8_quant(
|
||||
scale: Optional scaling factor for the FP8 quantization
|
||||
batch_dim_padding: If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
||||
in the dynamic quantization case.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
||||
@ -328,24 +331,21 @@ def scaled_fp8_quant(
|
||||
else:
|
||||
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
|
||||
if scale is None:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
|
||||
output, input, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
def dynamic_per_token_scaled_fp8_quant(
|
||||
input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
|
||||
scales = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
|
||||
return output, scales
|
||||
|
||||
|
||||
# int8
|
||||
def scaled_int8_quant(
|
||||
input: torch.Tensor,
|
||||
|
||||
@ -103,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported)
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||
use_per_token_if_dynamic=True)
|
||||
|
||||
@ -214,7 +214,8 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported)
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||
use_per_token_if_dynamic=False)
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@ -107,31 +107,43 @@ def apply_fp8_linear(
|
||||
input_scale: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_fp8_supported: bool = True,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
input,
|
||||
input_scale,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias)
|
||||
return ops.cutlass_scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias)
|
||||
|
||||
# torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
else:
|
||||
# Note: we pad the input because torch._scaled_mm is more performant
|
||||
# for matrices with batch dimension > 16.
|
||||
# This could change in the future.
|
||||
qinput, x_scale = ops.scaled_fp8_quant(input,
|
||||
input_scale,
|
||||
batch_dim_padding=17)
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
input,
|
||||
input_scale,
|
||||
batch_dim_padding=17,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic)
|
||||
|
||||
if weight_scale.numel() == 1:
|
||||
per_tensor_weights = (weight_scale.numel() == 1)
|
||||
per_tensor_activations = (x_scale.numel() == 1)
|
||||
|
||||
if per_tensor_weights and per_tensor_activations:
|
||||
# Fused GEMM_DQ
|
||||
output, _ = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
@ -139,9 +151,11 @@ def apply_fp8_linear(
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias)
|
||||
return torch.narrow(output, 0, 0, input.shape[0])
|
||||
|
||||
else:
|
||||
# Fallback for channelwise case, where the weight scales are
|
||||
# applied separately.
|
||||
# Fallback for channelwise case, where we use unfused DQ
|
||||
# due to limitations with scaled_mm
|
||||
|
||||
# Symmetric quantized GEMM by definition computes the following:
|
||||
# C = (s_x * X) (s_w * W) + bias
|
||||
@ -155,21 +169,21 @@ def apply_fp8_linear(
|
||||
# For the scaled_mm fallback case, we break this down, since it
|
||||
# does not support s_w being a vector.
|
||||
|
||||
# This computes C = sx * (X * W).
|
||||
# GEMM
|
||||
# This computes C = (X * W).
|
||||
# Output in fp32 to allow subsequent ops to happen in-place
|
||||
output, _ = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=torch.float32,
|
||||
scale_a=x_scale)
|
||||
out_dtype=torch.float32)
|
||||
# Unpad (undo batch_dim_padding)
|
||||
output = torch.narrow(output, 0, 0, input.shape[0])
|
||||
|
||||
# C = sw * sx * (X * W)
|
||||
output = output * weight_scale.t()
|
||||
# DQ
|
||||
# C = sw * sx * (X * W) + bias
|
||||
output = output * x_scale * weight_scale.t()
|
||||
if bias is not None:
|
||||
# C = sw * sx * (X * W) + bias
|
||||
output = output + bias
|
||||
output = output.to(dtype=input.dtype)
|
||||
|
||||
return torch.narrow(output, 0, 0, input.shape[0])
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
def apply_int8_linear(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user