mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 01:42:14 +08:00
first try
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
a806c14cc7
commit
974e6820ce
@ -10,6 +10,14 @@ from torch.nn import Parameter
|
|||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
|
_POSSIBLE_FP8_KERNELS,
|
||||||
|
choose_scaled_mm_linear_kernel,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||||
|
ScaledMMLinearLayerConfig,
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
W8A8BlockFp8LinearOp,
|
||||||
check_aiter_fp8_linear_support,
|
check_aiter_fp8_linear_support,
|
||||||
@ -24,7 +32,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp,
|
|
||||||
cutlass_block_fp8_supported,
|
cutlass_block_fp8_supported,
|
||||||
maybe_create_device_identity,
|
maybe_create_device_identity,
|
||||||
)
|
)
|
||||||
@ -72,9 +79,32 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.fp8_linear = Fp8LinearOp(
|
param_name_list = ["weight", "weight_scale", "input_scale"]
|
||||||
act_quant_static=self.is_static_input_scheme,
|
layer_mapping_function = lambda layer: (
|
||||||
act_quant_group_shape=self.act_q_group_shape,
|
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
||||||
|
param_name_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: clean up
|
||||||
|
if self.strategy == QuantizationStrategy.TENSOR:
|
||||||
|
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
||||||
|
|
||||||
|
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||||
|
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||||
|
is_static_input_scheme=self.is_static_input_scheme,
|
||||||
|
input_symmetric=True,
|
||||||
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
|
activation_group_shape=self.act_q_group_shape,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
)
|
||||||
|
kernel = choose_scaled_mm_linear_kernel(
|
||||||
|
scaled_mm_linear_kernel_config,
|
||||||
|
_POSSIBLE_FP8_KERNELS,
|
||||||
|
)
|
||||||
|
self.fp8_linear = kernel(
|
||||||
|
scaled_mm_linear_kernel_config, layer_mapping_function
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -190,11 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.fp8_linear.apply(
|
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||||
input=x,
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale,
|
|
||||||
out_dtype=self.out_dtype,
|
|
||||||
input_scale=layer.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -2,16 +2,36 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledMMLinearQuantStrategy(Enum):
|
||||||
|
TENSOR = "tensor"
|
||||||
|
CHANNEL = "channel"
|
||||||
|
BLOCK = "block"
|
||||||
|
|
||||||
|
def is_per_token(self) -> bool:
|
||||||
|
return self.row == 1 and self.col == -1
|
||||||
|
|
||||||
|
def is_per_group(self) -> bool:
|
||||||
|
return self.row == 1 and self.col >= 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScaledMMLinearLayerConfig:
|
class ScaledMMLinearLayerConfig:
|
||||||
|
# TODO: remove is channelwise
|
||||||
is_channelwise: bool
|
is_channelwise: bool
|
||||||
is_static_input_scheme: bool
|
is_static_input_scheme: bool
|
||||||
input_symmetric: bool
|
input_symmetric: bool
|
||||||
|
out_dtype: torch.dtype | None
|
||||||
|
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||||
|
activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
|
||||||
|
|
||||||
|
|
||||||
class ScaledMMLinearKernel(ABC):
|
class ScaledMMLinearKernel(ABC):
|
||||||
@ -26,21 +46,11 @@ class ScaledMMLinearKernel(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
c: ScaledMMLinearLayerConfig,
|
|
||||||
w_q_param_name: str,
|
|
||||||
w_s_param_name: str,
|
|
||||||
i_s_param_name: str,
|
|
||||||
i_zp_param_name: str,
|
|
||||||
azp_adj_param_name: str,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self.can_implement(c)
|
assert self.can_implement(c)
|
||||||
self.config = c
|
self.config = c
|
||||||
self.w_q_name = w_q_param_name
|
self.layer_mapping_function = layer_mapping_function
|
||||||
self.w_s_name = w_s_param_name
|
|
||||||
self.i_s_name = i_s_param_name
|
|
||||||
self.i_zp_name = i_zp_param_name
|
|
||||||
self.azp_adj_name = azp_adj_param_name
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
@ -55,19 +65,19 @@ class ScaledMMLinearKernel(ABC):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _get_weight_params(
|
# def _get_weight_params(
|
||||||
self, layer: torch.nn.Module
|
# self, layer: torch.nn.Module
|
||||||
) -> tuple[
|
# ) -> tuple[
|
||||||
torch.Tensor, # weight
|
# torch.Tensor, # weight
|
||||||
torch.Tensor, # weight_scale
|
# torch.Tensor, # weight_scale
|
||||||
torch.Tensor | None, # input_scale,
|
# torch.Tensor | None, # input_scale,
|
||||||
torch.Tensor | None, # input_zp
|
# torch.Tensor | None, # input_zp
|
||||||
torch.Tensor | None, # azp_adj
|
# torch.Tensor | None, # azp_adj
|
||||||
]:
|
# ]:
|
||||||
return (
|
# return (
|
||||||
getattr(layer, self.w_q_name),
|
# getattr(layer, self.w_q_name),
|
||||||
getattr(layer, self.w_s_name),
|
# getattr(layer, self.w_s_name),
|
||||||
getattr(layer, self.i_s_name),
|
# getattr(layer, self.i_s_name),
|
||||||
getattr(layer, self.i_zp_name),
|
# getattr(layer, self.i_zp_name),
|
||||||
getattr(layer, self.azp_adj_name),
|
# getattr(layer, self.azp_adj_name),
|
||||||
)
|
# )
|
||||||
|
|||||||
@ -12,10 +12,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
|||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||||
CutlassScaledMMLinearKernel,
|
CutlassScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||||
|
ROCmScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
||||||
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
|
PerTensorTorchScaledMMLinearKernel,
|
||||||
|
RowWiseTorchScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||||
TritonScaledMMLinearKernel,
|
TritonScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
@ -25,16 +33,28 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
|||||||
from vllm.platforms import PlatformEnum, current_platform
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
|
|
||||||
# in priority/performance order (when available)
|
# in priority/performance order (when available)
|
||||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||||
|
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||||
|
PlatformEnum.ROCM: [
|
||||||
|
ROCmScaledMMLinearKernel,
|
||||||
|
PerTensorTorchScaledMMLinearKernel,
|
||||||
|
RowWiseTorchScaledMMLinearKernel,
|
||||||
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def choose_scaled_mm_linear_kernel(
|
def choose_scaled_mm_linear_kernel(
|
||||||
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
|
config: ScaledMMLinearLayerConfig,
|
||||||
|
possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
|
||||||
|
compute_capability: int | None = None,
|
||||||
) -> type[ScaledMMLinearKernel]:
|
) -> type[ScaledMMLinearKernel]:
|
||||||
"""
|
"""
|
||||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||||
@ -61,7 +81,7 @@ def choose_scaled_mm_linear_kernel(
|
|||||||
compute_capability = _cc[0] * 10 + _cc[1]
|
compute_capability = _cc[0] * 10 + _cc[1]
|
||||||
|
|
||||||
failure_reasons = []
|
failure_reasons = []
|
||||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
for kernel in possible_kernels[current_platform._enum]:
|
||||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
||||||
failure_reasons.append(
|
failure_reasons.append(
|
||||||
f" {kernel.__name__} disabled by environment variable"
|
f" {kernel.__name__} disabled by environment variable"
|
||||||
|
|||||||
@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
from .cutlass import CutlassScaledMMLinearKernel
|
from .cutlass import process_weights_after_loading
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_gemm_w8a8_impl(
|
def rocm_aiter_gemm_w8a8_impl(
|
||||||
@ -52,7 +52,7 @@ if current_platform.is_rocm():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 90
|
return 90
|
||||||
@ -92,7 +92,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
|||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
super().process_weights_after_loading(layer)
|
_, param_names = self.layer_mapping_function(layer)
|
||||||
|
|
||||||
|
process_weights_after_loading(self.config, layer, *param_names)
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
@ -110,7 +112,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
|||||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||||
ATIER block scaled GEMM and mix-precision GEMM.
|
ATIER block scaled GEMM and mix-precision GEMM.
|
||||||
"""
|
"""
|
||||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
||||||
|
|
||||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||||
# * dynamic, i_s is None and x_s computed from x.
|
# * dynamic, i_s is None and x_s computed from x.
|
||||||
|
|||||||
@ -2,10 +2,14 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
)
|
)
|
||||||
@ -14,6 +18,111 @@ from vllm.platforms import current_platform
|
|||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_w8a8_scaled_mm(
|
||||||
|
*,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
output_shape: list,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Fused GEMM_DQ
|
||||||
|
output = ops.cutlass_scaled_mm(
|
||||||
|
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||||
|
)
|
||||||
|
return output.view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def process_weights_after_loading(
|
||||||
|
config: ScaledMMLinearLayerConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
w_q_name: str,
|
||||||
|
w_s_name: str,
|
||||||
|
i_s_name: str,
|
||||||
|
i_zp_name: str,
|
||||||
|
azp_adj_name: str,
|
||||||
|
):
|
||||||
|
# WEIGHT
|
||||||
|
# Cutlass kernels need transposed weight.
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
w_q_name,
|
||||||
|
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# WEIGHT SCALE
|
||||||
|
# Cutlass kernels support only per-tensor and per-channel.
|
||||||
|
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||||
|
# scales being passed to the kernel), convert to the per-channel case.
|
||||||
|
is_fused_module = len(layer.logical_widths) > 1
|
||||||
|
weight_scale = getattr(layer, w_s_name)
|
||||||
|
if is_fused_module and not config.is_channelwise:
|
||||||
|
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||||
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
w_s_name,
|
||||||
|
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# INPUT SCALE
|
||||||
|
if config.is_static_input_scheme:
|
||||||
|
input_scale = getattr(layer, i_s_name)
|
||||||
|
|
||||||
|
if config.input_symmetric:
|
||||||
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
i_s_name,
|
||||||
|
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||||
|
)
|
||||||
|
setattr(layer, i_zp_name, None)
|
||||||
|
else:
|
||||||
|
input_zero_point = getattr(layer, i_zp_name)
|
||||||
|
|
||||||
|
# reconstruct the ranges
|
||||||
|
int8_traits = torch.iinfo(torch.int8)
|
||||||
|
azps = input_zero_point.to(dtype=torch.int32)
|
||||||
|
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||||
|
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||||
|
|
||||||
|
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||||
|
replace_parameter(
|
||||||
|
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# AZP loaded as int8 but used as int32
|
||||||
|
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||||
|
replace_parameter(
|
||||||
|
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
setattr(layer, i_s_name, None)
|
||||||
|
setattr(layer, i_zp_name, None)
|
||||||
|
|
||||||
|
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||||
|
# It does not depend on scales or azp, so it is the same for
|
||||||
|
# static and dynamic quantization.
|
||||||
|
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||||
|
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||||
|
if not config.input_symmetric:
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
|
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||||
|
if config.is_static_input_scheme:
|
||||||
|
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||||
|
# in the per-tensor case
|
||||||
|
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
azp_adj_name,
|
||||||
|
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
setattr(layer, azp_adj_name, None)
|
||||||
|
|
||||||
|
|
||||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -27,83 +136,9 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
# WEIGHT
|
_, param_names = self.layer_mapping_function(layer)
|
||||||
# Cutlass kernels need transposed weight.
|
|
||||||
weight = getattr(layer, self.w_q_name)
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
self.w_q_name,
|
|
||||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# WEIGHT SCALE
|
process_weights_after_loading(self.config, layer, *param_names)
|
||||||
# Cutlass kernels support only per-tensor and per-channel.
|
|
||||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
|
||||||
# scales being passed to the kernel), convert to the per-channel case.
|
|
||||||
is_fused_module = len(layer.logical_widths) > 1
|
|
||||||
weight_scale = getattr(layer, self.w_s_name)
|
|
||||||
if is_fused_module and not self.config.is_channelwise:
|
|
||||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
self.w_s_name,
|
|
||||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# INPUT SCALE
|
|
||||||
if self.config.is_static_input_scheme:
|
|
||||||
input_scale = getattr(layer, self.i_s_name)
|
|
||||||
|
|
||||||
if self.config.input_symmetric:
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
self.i_s_name,
|
|
||||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
|
||||||
)
|
|
||||||
setattr(layer, self.i_zp_name, None)
|
|
||||||
else:
|
|
||||||
input_zero_point = getattr(layer, self.i_zp_name)
|
|
||||||
|
|
||||||
# reconstruct the ranges
|
|
||||||
int8_traits = torch.iinfo(torch.int8)
|
|
||||||
azps = input_zero_point.to(dtype=torch.int32)
|
|
||||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
|
||||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
|
||||||
|
|
||||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
|
||||||
replace_parameter(
|
|
||||||
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# AZP loaded as int8 but used as int32
|
|
||||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
|
||||||
replace_parameter(
|
|
||||||
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
setattr(layer, self.i_s_name, None)
|
|
||||||
setattr(layer, self.i_zp_name, None)
|
|
||||||
|
|
||||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
|
||||||
# It does not depend on scales or azp, so it is the same for
|
|
||||||
# static and dynamic quantization.
|
|
||||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
|
||||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
|
||||||
if not self.config.input_symmetric:
|
|
||||||
weight = getattr(layer, self.w_q_name)
|
|
||||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
|
||||||
if self.config.is_static_input_scheme:
|
|
||||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
|
||||||
# in the per-tensor case
|
|
||||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
|
||||||
setattr(
|
|
||||||
layer,
|
|
||||||
self.azp_adj_name,
|
|
||||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
setattr(layer, self.azp_adj_name, None)
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
@ -111,7 +146,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
||||||
|
|
||||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||||
# * dynamic, i_s is None and x_s computed from x.
|
# * dynamic, i_s is None and x_s computed from x.
|
||||||
@ -138,3 +173,70 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return ops.cutlass_scaled_mm(
|
return ops.cutlass_scaled_mm(
|
||||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
|
def __init__(
|
||||||
|
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
|
) -> None:
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=c.is_static_input_scheme,
|
||||||
|
group_shape=GroupShape.PER_TENSOR,
|
||||||
|
num_token_padding=None,
|
||||||
|
)
|
||||||
|
super().__init__(c, layer_mapping_function)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# lovelace and up
|
||||||
|
return 89
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"CutlassFP8ScaledMMLinearKernel is supported "
|
||||||
|
+ "on CUDA platforms Only.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||||
|
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = self.quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||||
|
|
||||||
|
return cutlass_w8a8_scaled_mm(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
|
|||||||
@ -0,0 +1,120 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||||
|
|
||||||
|
from .ScaledMMLinearKernel import (
|
||||||
|
ScaledMMLinearKernel,
|
||||||
|
ScaledMMLinearLayerConfig,
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_w8a8_scaled_mm(
|
||||||
|
*,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return flashinfer_scaled_fp8_mm(
|
||||||
|
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
|
def __init__(
|
||||||
|
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
|
) -> None:
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=c.is_static_input_scheme,
|
||||||
|
group_shape=GroupShape.PER_TENSOR,
|
||||||
|
num_token_padding=None,
|
||||||
|
)
|
||||||
|
super().__init__(c, layer_mapping_function)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 100
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
|
per_tensor_weight_scales = (
|
||||||
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
)
|
||||||
|
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"FlashInferScaledMMLinearKernel is supported "
|
||||||
|
+ "on CUDA platforms Only.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_flashinfer():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"FlashInferScaledMMLinearKernel requires "
|
||||||
|
+ "FlashInfer to be installed.",
|
||||||
|
)
|
||||||
|
if not has_flashinfer():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"FlashInferScaledMMLinearKernel requires "
|
||||||
|
+ "FlashInfer to be installed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"FlashInferScaledMMLinearKernel requires "
|
||||||
|
+ "per tensor activation and weight scales.",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||||
|
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = self.quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||||
|
|
||||||
|
return flashinfer_w8a8_scaled_mm(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
@ -0,0 +1,179 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .ScaledMMLinearKernel import (
|
||||||
|
ScaledMMLinearKernel,
|
||||||
|
ScaledMMLinearLayerConfig,
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if (
|
||||||
|
A.shape[0] == 1
|
||||||
|
and B.shape[1] % 16 == 0
|
||||||
|
and ((bias is None) or (bias.dtype == out_dtype))
|
||||||
|
):
|
||||||
|
output = ops.wvSplitKQ(
|
||||||
|
B.t(),
|
||||||
|
A,
|
||||||
|
out_dtype,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
current_platform.get_cu_count(),
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
# Fallabck
|
||||||
|
else:
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
scale_a=As,
|
||||||
|
scale_b=Bs,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_per_tensor_float_w8a8_scaled_mm(
|
||||||
|
*,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
output_shape: list[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||||
|
A, B, out_dtype, As, Bs, bias
|
||||||
|
)
|
||||||
|
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||||
|
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||||
|
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
|
def __init__(
|
||||||
|
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
|
) -> None:
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=c.is_static_input_scheme,
|
||||||
|
group_shape=GroupShape.PER_TENSOR,
|
||||||
|
num_token_padding=None,
|
||||||
|
)
|
||||||
|
super().__init__(c, layer_mapping_function)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 90
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
# TODO: check if this causes an issue on non-ROCM platforms
|
||||||
|
from vllm.platforms.rocm import on_mi3xx
|
||||||
|
|
||||||
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
|
per_tensor_weight_scales = (
|
||||||
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
)
|
||||||
|
|
||||||
|
if not current_platform.is_rocm():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"ROCmScaledMMLinearFP8Kernel is supported " + "on ROCm platforms Only.",
|
||||||
|
)
|
||||||
|
if not on_mi3xx():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"ROCmScaledMMLinearFP8Kernel is supported "
|
||||||
|
+ "on MI3xx architures only.",
|
||||||
|
)
|
||||||
|
if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
|
||||||
|
+ "to use ROCmScaledMMLinearKernel ",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"ROCmScaledMMLinearKernel requires "
|
||||||
|
+ "per tensor activation and weight scales.",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||||
|
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = self.quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||||
|
|
||||||
|
return rocm_per_tensor_float_w8a8_scaled_mm(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
@ -0,0 +1,343 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from vllm.config import CompilationMode, get_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .ScaledMMLinearKernel import (
|
||||||
|
ScaledMMLinearKernel,
|
||||||
|
ScaledMMLinearLayerConfig,
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
|
TORCH_DEVICE_IDENTITY = None
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_create_device_identity():
|
||||||
|
# Allocate dummy ones tensor for torch._scaled_mm
|
||||||
|
global TORCH_DEVICE_IDENTITY
|
||||||
|
if TORCH_DEVICE_IDENTITY is None:
|
||||||
|
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_per_tensor_w8a8_scaled_mm(
|
||||||
|
*,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
output_shape: list,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||||
|
)
|
||||||
|
# A fix for discrepancy in scaled_mm which returns tuple
|
||||||
|
# for torch < 2.5 and a single value in torch >= 2.5
|
||||||
|
if type(output) is tuple and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_row_wise_w8a8_scaled_mm(
|
||||||
|
*,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
output_shape: list,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||||
|
# when using it.
|
||||||
|
# For now it has only been validated on ROCm platform.
|
||||||
|
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
||||||
|
# https://github.com/pytorch/pytorch/pull/144432 using
|
||||||
|
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
|
||||||
|
#
|
||||||
|
# For CUDA platform please validate if the torch._scaled_mm supports
|
||||||
|
# rowwise scaled GEMM before using it
|
||||||
|
|
||||||
|
# Fused GEMM_DQ Rowwise GEMM
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
scale_a=As,
|
||||||
|
scale_b=Bs.t(),
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch.narrow(output, 0, 0, A.shape[0])
|
||||||
|
output = output.view(*output_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def torch_channelwise_w8a8_scaled_mm(
|
||||||
|
*,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
output_shape: list,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Use unfused DQ due to limitations with scaled_mm
|
||||||
|
|
||||||
|
# Symmetric quantized GEMM by definition computes the following:
|
||||||
|
# C = (s_x * X) (s_w * W) + bias
|
||||||
|
# This is equivalent to dequantizing the weights and activations
|
||||||
|
# before applying a GEMM.
|
||||||
|
#
|
||||||
|
# In order to compute quantized operands, a quantized kernel
|
||||||
|
# will rewrite the above like so:
|
||||||
|
# C = s_w * s_x * (X * W) + bias
|
||||||
|
#
|
||||||
|
# For the scaled_mm fallback case, we break this down, since it
|
||||||
|
# does not support s_w being a vector.
|
||||||
|
|
||||||
|
# GEMM
|
||||||
|
# This computes C = (X * W).
|
||||||
|
# Output in fp32 to allow subsequent ops to happen in-place
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
scale_a=TORCH_DEVICE_IDENTITY,
|
||||||
|
scale_b=TORCH_DEVICE_IDENTITY,
|
||||||
|
out_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
# A fix for discrepancy in scaled_mm which returns tuple
|
||||||
|
# for torch < 2.5 and a single value in torch >= 2.5
|
||||||
|
if type(output) is tuple and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
# Unpad (undo num_token_padding)
|
||||||
|
output = torch.narrow(output, 0, 0, A.shape[0])
|
||||||
|
x_scale = torch.narrow(As, 0, 0, A.shape[0])
|
||||||
|
|
||||||
|
# DQ
|
||||||
|
# C = sw * sx * (X * W) + bias
|
||||||
|
output = output * x_scale * Bs.t()
|
||||||
|
if bias is not None:
|
||||||
|
output = output + bias
|
||||||
|
return output.to(out_dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
|
def __init__(
|
||||||
|
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
|
) -> None:
|
||||||
|
vllm_config = get_current_vllm_config().compilation_config
|
||||||
|
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||||
|
|
||||||
|
output_padding = 17 if pad_output else None
|
||||||
|
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=c.is_static_input_scheme,
|
||||||
|
group_shape=GroupShape.PER_TENSOR,
|
||||||
|
num_token_padding=output_padding,
|
||||||
|
)
|
||||||
|
super().__init__(c, layer_mapping_function)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# lovelace and up
|
||||||
|
return 89
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
assert c.activation_group_shape is not None
|
||||||
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
|
per_tensor_weight_scales = (
|
||||||
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"PerTensorTorchScaledMMLinearKernel requires "
|
||||||
|
+ "per tensor activation and weight scales.",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||||
|
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||||
|
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = self.quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||||
|
return torch_per_tensor_w8a8_scaled_mm(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 94
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
assert c.activation_group_shape is not None
|
||||||
|
|
||||||
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
|
per_tensor_weight_scales = (
|
||||||
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
)
|
||||||
|
|
||||||
|
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"RowWiseTorchScaledMMLinearKernel cannot be used with "
|
||||||
|
+ "per tensor activation and weight scales.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not current_platform.is_rocm():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"RowWiseTorchScaledMMLinearKernel is only supported "
|
||||||
|
+ "in ROCm platforms.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not version.parse(torch.__version__) >= version.parse("2.7"):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"RowWiseTorchScaledMMLinearKernel requires " + "pytorch version >=2.7.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||||
|
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||||
|
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = self.quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||||
|
return torch_row_wise_w8a8_scaled_mm(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 94
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
assert c.activation_group_shape is not None
|
||||||
|
|
||||||
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
|
per_tensor_weight_scales = (
|
||||||
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
)
|
||||||
|
|
||||||
|
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"ChannelWiseTorchScaledMMLinearKernel cannot be used with "
|
||||||
|
+ "per tensor activation and weight scales.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||||
|
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||||
|
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = self.quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||||
|
return torch_channelwise_w8a8_scaled_mm(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
Loading…
x
Reference in New Issue
Block a user