diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 380431e864355..f2d8eecdc68e5 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,19 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch from packaging import version from vllm import _custom_ops as ops -from vllm import envs -from vllm.config import CompilationMode, get_current_vllm_config -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer -from vllm.utils.torch_utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -143,354 +136,6 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) -def cutlass_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - return output.view(*output_shape) - - -def flashinfer_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - return flashinfer_scaled_fp8_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - - -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - from vllm.platforms.rocm import on_mi3xx - - if ( - envs.VLLM_ROCM_USE_SKINNY_GEMM - and on_mi3xx() - and qinput.shape[0] == 1 - and qinput.shape[1] % 16 == 0 - and ((bias is None) or (bias.dtype == out_dtype)) - ): - output = ops.wvSplitKQ( - weight.t(), - qinput, - out_dtype, - scale_a, - scale_b, - current_platform.get_cu_count(), - bias, - ) - else: - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias, - ) - return output - - -def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) - - -def rocm_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias - ) - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -direct_register_custom_op( - op_name="rocm_per_tensor_w8a8_scaled_mm_impl", - op_func=rocm_per_tensor_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, -) - - -def torch_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch._scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -def torch_per_token_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM - # when using it. - # For now it has only been validated on ROCm platform. - # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using - # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. - # - # For CUDA platform please validate if the torch._scaled_mm supports - # rowwise scaled GEMM before using it - - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b.t(), - bias=bias, - ) - - output = torch.narrow(output, 0, 0, qinput.shape[0]) - output = output.view(*output_shape) - return output - - -def torch_channelwise_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Use unfused DQ due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). - # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm( - qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, qinput.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * scale_b.t() - if bias is not None: - output = output + bias - return output.to(out_dtype).view(*output_shape) - - -def dispatch_w8a8_scaled_mm( - preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool -) -> Callable[..., torch.Tensor]: - if per_tensor_weights and per_tensor_activations: - if preferred_backend == "rocm": - return rocm_per_tensor_w8a8_scaled_mm - if preferred_backend == "flashinfer": - return flashinfer_w8a8_scaled_mm - if preferred_backend == "cutlass": - return cutlass_w8a8_scaled_mm - return torch_per_tensor_w8a8_scaled_mm - - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if preferred_backend == "cutlass" or preferred_backend == "flashinfer": - return cutlass_w8a8_scaled_mm - - # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if ( - not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM - ): - return torch_per_token_w8a8_scaled_mm - # Normally, torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - return torch_channelwise_w8a8_scaled_mm - - -# TODO(luka): follow similar pattern for marlin and block-fp8-linear -# https://github.com/vllm-project/vllm/issues/14397 -class Fp8LinearOp: - """ - This class executes a FP8 linear layer using cutlass if supported and - torch.scaled_mm otherwise. - It needs to be a class instead of a method so that config can be read - in the __init__ method, as reading config is not allowed inside forward. - """ - - def __init__( - self, - act_quant_static: bool, - act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: bool | None = None, - ): - if current_platform.is_rocm(): - self.preferred_backend = "rocm" - elif current_platform.is_cuda() and cutlass_fp8_supported(): - if has_flashinfer() and current_platform.has_device_capability(100): - self.preferred_backend = "flashinfer" - else: - self.preferred_backend = "cutlass" - else: - self.preferred_backend = "torch" - - # Note: we pad the input because torch._scaled_mm is more performant - # for matrices with batch dimension > 16. - # This could change in the future. - # We also don't pad when using torch.compile, - # as it breaks with dynamic shapes. - if pad_output is None: - config = get_current_vllm_config().compilation_config - pad_output = ( - config.mode < CompilationMode.VLLM_COMPILE - and self.preferred_backend == "torch" - ) - - self.output_padding = 17 if pad_output else None - self.act_quant_static = act_quant_static - self.act_quant_group_shape = act_quant_group_shape - self.quant_fp8 = QuantFP8( - static=act_quant_static, - group_shape=act_quant_group_shape, - num_token_padding=self.output_padding, - ) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - out_dtype: torch.dtype | None = None, - input_scale: torch.Tensor | None = None, - input_scale_ub: torch.Tensor | None = None, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. - - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[1]] - - if out_dtype is None: - out_dtype = input.dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - if input.dtype != current_platform.fp8_dtype(): - qinput, x_scale = self.quant_fp8( - input_2d, - input_scale, - input_scale_ub, - ) - else: - qinput, x_scale = input_2d, input_scale - - # Must have dim() conditions - # In per-token quant scenario, when the number of token is 1, - # the scale will only have 1 elements. - # Without checking the dim(), - # we cannot distingushes between per-tensor and per-token quant. - # Example: - # When the number of token is 1, per-token scale is [[1]] - # When per-tensor scale is [1] or (). - per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 - per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 - - # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.preferred_backend, per_tensor_weights, per_tensor_activations - ) - - return w8a8_scaled_mm_func( - qinput=qinput, - weight=weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - output_shape=output_shape, - ) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor,