mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 06:17:03 +08:00
remove FP8LinearOps
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
52ff537459
commit
b13c4bb25c
@ -1,19 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm import envs
|
|
||||||
from vllm.config import CompilationMode, get_current_vllm_config
|
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
@ -143,354 +136,6 @@ def maybe_create_device_identity():
|
|||||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
def cutlass_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Run the CUTLASS scaled matmul and reshape its result.

    All arguments are keyword-only. ``qinput`` and ``weight`` are handed to
    ``ops.cutlass_scaled_mm`` together with ``out_dtype``, ``scale_a``,
    ``scale_b`` and ``bias``. Any additional keyword arguments are accepted
    and ignored.

    Returns:
        The kernel output viewed as ``output_shape``.
    """
    # Fused GEMM + dequantization (GEMM_DQ) in one kernel call.
    gemm_result = ops.cutlass_scaled_mm(
        qinput,
        weight,
        out_dtype=out_dtype,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias,
    )
    return gemm_result.view(*output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Run the FlashInfer FP8 scaled matmul and reshape its result.

    All arguments are keyword-only. ``qinput`` and ``weight`` are handed to
    ``flashinfer_scaled_fp8_mm`` together with ``out_dtype``, ``scale_a``,
    ``scale_b`` and ``bias``. Any additional keyword arguments are accepted
    and ignored.

    Returns:
        The kernel output viewed as ``output_shape``.
    """
    output = flashinfer_scaled_fp8_mm(
        qinput,
        weight,
        out_dtype=out_dtype,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias,
    )
    # Every other backend in this module returns its result viewed as
    # ``output_shape``; do the same here so a caller that flattened a
    # higher-rank input gets the requested shape back regardless of backend.
    # When the kernel output already has that shape the view is a no-op.
    return output.view(*output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def rocm_per_tensor_w8a8_scaled_mm_impl(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Per-tensor scaled matmul used on ROCm.

    Takes the skinny-GEMM path (``ops.wvSplitKQ``) only when all of the
    following hold: ``envs.VLLM_ROCM_USE_SKINNY_GEMM`` is truthy,
    ``on_mi3xx()`` returns true, ``qinput`` has exactly one row, its second
    dimension is a multiple of 16, and ``bias`` is either ``None`` or already
    of dtype ``out_dtype``. Otherwise falls back to ``torch._scaled_mm``.

    Returns:
        The output tensor of whichever kernel was selected.
    """
    from vllm.platforms.rocm import on_mi3xx

    # Pure attribute reads, safe to evaluate ahead of the platform checks.
    bias_compatible = bias is None or bias.dtype == out_dtype

    use_skinny_gemm = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_mi3xx()
        and qinput.shape[0] == 1
        and qinput.shape[1] % 16 == 0
        and bias_compatible
    )
    if use_skinny_gemm:
        return ops.wvSplitKQ(
            weight.t(),
            qinput,
            out_dtype,
            scale_a,
            scale_b,
            current_platform.get_cu_count(),
            bias,
        )

    return torch._scaled_mm(
        qinput,
        weight,
        out_dtype=out_dtype,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def rocm_per_tensor_w8a8_scaled_mm_fake(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Shape-only stand-in for ``rocm_per_tensor_w8a8_scaled_mm_impl``.

    Performs no computation. Returns an uninitialized tensor allocated from
    ``qinput`` whose shape is ``qinput``'s leading dimensions followed by
    ``weight.shape[1]`` and whose dtype is ``out_dtype``. ``scale_a``,
    ``scale_b`` and ``bias`` are unused.
    """
    leading_dims = tuple(qinput.shape[:-1])
    out_features = weight.shape[1]
    return qinput.new_empty(leading_dims + (out_features,), dtype=out_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def rocm_per_tensor_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Per-tensor scaled matmul for ROCm, routed through the custom op.

    All arguments are keyword-only. The computation is delegated to
    ``torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl``. Additional
    keyword arguments are accepted and ignored, matching the other backend
    functions that ``dispatch_w8a8_scaled_mm`` can return, so all of them
    can be called with the same keyword set.

    Returns:
        The op output limited to the first ``qinput.shape[0]`` rows along
        dimension 0 and viewed as ``output_shape``.
    """
    output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
        qinput, weight, out_dtype, scale_a, scale_b, bias
    )
    # Keep only the first qinput.shape[0] rows before reshaping.
    return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
# Register the ROCm per-tensor implementation, paired with its fake
# implementation, as a custom op. rocm_per_tensor_w8a8_scaled_mm invokes it
# as torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl.
direct_register_custom_op(
    fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
    op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
    op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_per_tensor_w8a8_scaled_mm(
|
|
||||||
*,
|
|
||||||
qinput: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype,
|
|
||||||
scale_a: torch.Tensor,
|
|
||||||
scale_b: torch.Tensor,
|
|
||||||
bias: torch.Tensor,
|
|
||||||
output_shape: list,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
output = torch._scaled_mm(
|
|
||||||
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
|
|
||||||
)
|
|
||||||
# A fix for discrepancy in scaled_mm which returns tuple
|
|
||||||
# for torch < 2.5 and a single value in torch >= 2.5
|
|
||||||
if type(output) is tuple and len(output) == 2:
|
|
||||||
output = output[0]
|
|
||||||
|
|
||||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_per_token_w8a8_scaled_mm(
|
|
||||||
*,
|
|
||||||
qinput: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype,
|
|
||||||
scale_a: torch.Tensor,
|
|
||||||
scale_b: torch.Tensor,
|
|
||||||
bias: torch.Tensor,
|
|
||||||
output_shape: list,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
|
||||||
# when using it.
|
|
||||||
# For now it has only been validated on ROCm platform.
|
|
||||||
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
|
||||||
# https://github.com/pytorch/pytorch/pull/144432 using
|
|
||||||
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
|
|
||||||
#
|
|
||||||
# For CUDA platform please validate if the torch._scaled_mm supports
|
|
||||||
# rowwise scaled GEMM before using it
|
|
||||||
|
|
||||||
# Fused GEMM_DQ Rowwise GEMM
|
|
||||||
output = torch._scaled_mm(
|
|
||||||
qinput,
|
|
||||||
weight,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
scale_a=scale_a,
|
|
||||||
scale_b=scale_b.t(),
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = torch.narrow(output, 0, 0, qinput.shape[0])
|
|
||||||
output = output.view(*output_shape)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def torch_channelwise_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Scaled matmul with dequantization applied outside the GEMM.

    All arguments are keyword-only; extra keyword arguments are accepted and
    ignored. The GEMM runs through ``torch._scaled_mm`` with identity scales
    and float32 output; the result is then multiplied by ``scale_a`` and the
    transposed ``scale_b``, ``bias`` is added when it is not ``None``, and
    the result is cast to ``out_dtype``.

    Returns:
        The dequantized result viewed as ``output_shape``.
    """
    # Use unfused DQ due to limitations with scaled_mm

    # Symmetric quantized GEMM by definition computes the following:
    #   C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    #   C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.

    # GEMM
    # This computes C = (X * W).
    # The output is requested in fp32 so that the scaling below is applied
    # before the result is narrowed to out_dtype.
    output = torch._scaled_mm(
        qinput,
        weight,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]
    # Unpad (undo num_token_padding)
    output = torch.narrow(output, 0, 0, qinput.shape[0])
    x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])

    # DQ
    # C = sw * sx * (X * W) + bias
    output = output * x_scale * scale_b.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def dispatch_w8a8_scaled_mm(
    preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool
) -> Callable[..., torch.Tensor]:
    """Pick the scaled-mm function for a backend and quantization layout.

    Selection rules, in order:

    1. Per-tensor weights and activations: "rocm", "flashinfer" and
       "cutlass" map to their dedicated functions; any other backend uses
       ``torch_per_tensor_w8a8_scaled_mm``.
    2. Otherwise, "cutlass" and "flashinfer" both use
       ``cutlass_w8a8_scaled_mm``.
    3. Otherwise, when neither side is per-tensor and
       ``USE_ROWWISE_TORCH_SCALED_MM`` is truthy, use
       ``torch_per_token_w8a8_scaled_mm``.
    4. Otherwise, use ``torch_channelwise_w8a8_scaled_mm``.

    Returns:
        The selected function (not its result).
    """
    if per_tensor_weights and per_tensor_activations:
        per_tensor_kernels = {
            "rocm": rocm_per_tensor_w8a8_scaled_mm,
            "flashinfer": flashinfer_w8a8_scaled_mm,
            "cutlass": cutlass_w8a8_scaled_mm,
        }
        return per_tensor_kernels.get(
            preferred_backend, torch_per_tensor_w8a8_scaled_mm
        )

    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
    if preferred_backend in ("cutlass", "flashinfer"):
        return cutlass_w8a8_scaled_mm

    # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
    any_per_tensor = per_tensor_weights or per_tensor_activations
    if not any_per_tensor and USE_ROWWISE_TORCH_SCALED_MM:
        return torch_per_token_w8a8_scaled_mm

    # Normally, torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    return torch_channelwise_w8a8_scaled_mm
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
|
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
|
||||||
class Fp8LinearOp:
    """
    Execute an FP8 linear layer with a backend chosen for the platform.

    The preferred backend is fixed at construction time: "rocm" on ROCm;
    on CUDA with cutlass FP8 support, "flashinfer" when FlashInfer is
    available and the device has capability 100, else "cutlass"; and
    "torch" (torch._scaled_mm) in every other case.

    It needs to be a class instead of a method so that config can be read
    in the __init__ method, as reading config is not allowed inside forward.
    """

    def __init__(
        self,
        act_quant_static: bool,
        act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
        pad_output: bool | None = None,
    ):
        """Choose the backend and build the activation quantizer.

        Args:
            act_quant_static: Passed to ``QuantFP8`` as ``static``.
            act_quant_group_shape: Passed to ``QuantFP8`` as ``group_shape``.
            pad_output: Whether to pad the number of tokens during
                quantization. When ``None``, padding is enabled only if the
                compilation mode is below ``CompilationMode.VLLM_COMPILE``
                and the preferred backend is "torch".
        """
        if current_platform.is_rocm():
            backend = "rocm"
        elif current_platform.is_cuda() and cutlass_fp8_supported():
            use_flashinfer = (
                has_flashinfer() and current_platform.has_device_capability(100)
            )
            backend = "flashinfer" if use_flashinfer else "cutlass"
        else:
            backend = "torch"
        self.preferred_backend = backend

        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        if pad_output is None:
            config = get_current_vllm_config().compilation_config
            pad_output = (
                config.mode < CompilationMode.VLLM_COMPILE
                and self.preferred_backend == "torch"
            )

        # 17 is the smallest token count above the 16-row threshold
        # mentioned in the note above.
        self.output_padding = 17 if pad_output else None
        self.act_quant_static = act_quant_static
        self.act_quant_group_shape = act_quant_group_shape
        self.quant_fp8 = QuantFP8(
            static=act_quant_static,
            group_shape=act_quant_group_shape,
            num_token_padding=self.output_padding,
        )

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = None,
        input_scale: torch.Tensor | None = None,
        input_scale_ub: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize the input if needed and run the scaled matmul.

        Args:
            input: Activations; flattened to 2D over the last dimension.
            weight: Weight tensor; ``weight.shape[1]`` becomes the last
                dimension of the output.
            weight_scale: Scale passed to the kernel as ``scale_b``.
            out_dtype: Output dtype; defaults to ``input.dtype``.
            input_scale: Activation scale. Required when ``input`` is
                already in the platform FP8 dtype.
            input_scale_ub: Forwarded to the quantizer.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The result with shape ``[*input.shape[:-1], weight.shape[1]]``.

        Raises:
            ValueError: If ``input`` is already FP8 and ``input_scale`` is
                ``None``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.

        # View input as 2D matrix for fp8 methods
        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[1]]

        if out_dtype is None:
            out_dtype = input.dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        if input.dtype != current_platform.fp8_dtype():
            qinput, x_scale = self.quant_fp8(
                input_2d,
                input_scale,
                input_scale_ub,
            )
        else:
            # A pre-quantized input carries no information to derive a scale
            # from; fail with a clear message instead of an AttributeError
            # on ``None`` further down.
            if input_scale is None:
                raise ValueError(
                    "input_scale must be provided when the input is already "
                    "quantized to FP8"
                )
            qinput, x_scale = input_2d, input_scale

        # Must have dim() conditions
        # In per-token quant scenario, when the number of token is 1,
        # the scale will only have 1 elements.
        # Without checking the dim(),
        # we cannot distinguish between per-tensor and per-token quant.
        # Example:
        # When the number of token is 1, per-token scale is [[1]]
        # When per-tensor scale is [1] or ().
        per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

        # TODO(luka) do this dispatch during init (after ScaledMM refactor)
        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
            self.preferred_backend, per_tensor_weights, per_tensor_activations
        )

        return w8a8_scaled_mm_func(
            qinput=qinput,
            weight=weight,
            out_dtype=out_dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
            output_shape=output_shape,
        )
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_e4m3fn_to_e4m3fnuz(
|
def normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user