mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 13:07:21 +08:00
first try
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
a806c14cc7
commit
974e6820ce
@ -10,6 +10,14 @@ from torch.nn import Parameter
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
check_aiter_fp8_linear_support,
|
||||
@ -24,7 +32,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
@ -72,9 +79,32 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
param_name_list = ["weight", "weight_scale", "input_scale"]
|
||||
layer_mapping_function = lambda layer: (
|
||||
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
||||
param_name_list,
|
||||
)
|
||||
|
||||
# TODO: clean up
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=True,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
activation_group_shape=self.act_q_group_shape,
|
||||
out_dtype=self.out_dtype,
|
||||
)
|
||||
kernel = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
)
|
||||
self.fp8_linear = kernel(
|
||||
scaled_mm_linear_kernel_config, layer_mapping_function
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -190,11 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -2,16 +2,36 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
|
||||
|
||||
class ScaledMMLinearQuantStrategy(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
BLOCK = "block"
|
||||
|
||||
def is_per_token(self) -> bool:
|
||||
return self.row == 1 and self.col == -1
|
||||
|
||||
def is_per_group(self) -> bool:
|
||||
return self.row == 1 and self.col >= 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
# TODO: remove is channelwise
|
||||
is_channelwise: bool
|
||||
is_static_input_scheme: bool
|
||||
input_symmetric: bool
|
||||
out_dtype: torch.dtype | None
|
||||
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||
activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(ABC):
|
||||
@ -26,21 +46,11 @@ class ScaledMMLinearKernel(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: ScaledMMLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
i_s_param_name: str,
|
||||
i_zp_param_name: str,
|
||||
azp_adj_param_name: str,
|
||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
self.i_s_name = i_s_param_name
|
||||
self.i_zp_name = i_zp_param_name
|
||||
self.azp_adj_name = azp_adj_param_name
|
||||
self.layer_mapping_function = layer_mapping_function
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@ -55,19 +65,19 @@ class ScaledMMLinearKernel(ABC):
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module
|
||||
) -> tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.i_s_name),
|
||||
getattr(layer, self.i_zp_name),
|
||||
getattr(layer, self.azp_adj_name),
|
||||
)
|
||||
# def _get_weight_params(
|
||||
# self, layer: torch.nn.Module
|
||||
# ) -> tuple[
|
||||
# torch.Tensor, # weight
|
||||
# torch.Tensor, # weight_scale
|
||||
# torch.Tensor | None, # input_scale,
|
||||
# torch.Tensor | None, # input_zp
|
||||
# torch.Tensor | None, # azp_adj
|
||||
# ]:
|
||||
# return (
|
||||
# getattr(layer, self.w_q_name),
|
||||
# getattr(layer, self.w_s_name),
|
||||
# getattr(layer, self.i_s_name),
|
||||
# getattr(layer, self.i_zp_name),
|
||||
# getattr(layer, self.azp_adj_name),
|
||||
# )
|
||||
|
||||
@ -12,10 +12,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
ROCmScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
RowWiseTorchScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
TritonScaledMMLinearKernel,
|
||||
)
|
||||
@ -25,16 +33,28 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
|
||||
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [
|
||||
ROCmScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
RowWiseTorchScaledMMLinearKernel,
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
|
||||
config: ScaledMMLinearLayerConfig,
|
||||
possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
|
||||
compute_capability: int | None = None,
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
@ -61,7 +81,7 @@ def choose_scaled_mm_linear_kernel(
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
for kernel in possible_kernels[current_platform._enum]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} disabled by environment variable"
|
||||
|
||||
@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
from .cutlass import process_weights_after_loading
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_impl(
|
||||
@ -52,7 +52,7 @@ if current_platform.is_rocm():
|
||||
)
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
@ -92,7 +92,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
_, param_names = self.layer_mapping_function(layer)
|
||||
|
||||
process_weights_after_loading(self.config, layer, *param_names)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@ -110,7 +112,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||
ATIER block scaled GEMM and mix-precision GEMM.
|
||||
"""
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
|
||||
@ -2,10 +2,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
@ -14,6 +18,111 @@ from vllm.platforms import current_platform
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
def cutlass_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def process_weights_after_loading(
|
||||
config: ScaledMMLinearLayerConfig,
|
||||
layer: torch.nn.Module,
|
||||
w_q_name: str,
|
||||
w_s_name: str,
|
||||
i_s_name: str,
|
||||
i_zp_name: str,
|
||||
azp_adj_name: str,
|
||||
):
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
else:
|
||||
setattr(layer, i_s_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
if not config.input_symmetric:
|
||||
weight = getattr(layer, w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, azp_adj_name, None)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -27,83 +136,9 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
_, param_names = self.layer_mapping_function(layer)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
self.azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
process_weights_after_loading(self.config, layer, *param_names)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@ -111,7 +146,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
@ -138,3 +173,70 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return ops.cutlass_scaled_mm(
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
|
||||
|
||||
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=None,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return (
|
||||
False,
|
||||
"CutlassFP8ScaledMMLinearKernel is supported "
|
||||
+ "on CUDA platforms Only.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
|
||||
return cutlass_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_scaled_fp8_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
|
||||
|
||||
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=None,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
)
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel is supported "
|
||||
+ "on CUDA platforms Only.",
|
||||
)
|
||||
|
||||
if not has_flashinfer():
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel requires "
|
||||
+ "FlashInfer to be installed.",
|
||||
)
|
||||
if not has_flashinfer():
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel requires "
|
||||
+ "FlashInfer to be installed.",
|
||||
)
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel requires "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
|
||||
return flashinfer_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
A.shape[0] == 1
|
||||
and B.shape[1] % 16 == 0
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
B.t(),
|
||||
A,
|
||||
out_dtype,
|
||||
As,
|
||||
Bs,
|
||||
current_platform.get_cu_count(),
|
||||
bias,
|
||||
)
|
||||
# Fallabck
|
||||
else:
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs,
|
||||
bias=bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list[int],
|
||||
) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
A, B, out_dtype, As, Bs, bias
|
||||
)
|
||||
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=None,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
# TODO: check if this causes an issue on non-ROCM platforms
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
)
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
"ROCmScaledMMLinearFP8Kernel is supported " + "on ROCm platforms Only.",
|
||||
)
|
||||
if not on_mi3xx():
|
||||
return (
|
||||
False,
|
||||
"ROCmScaledMMLinearFP8Kernel is supported "
|
||||
+ "on MI3xx architures only.",
|
||||
)
|
||||
if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
|
||||
return (
|
||||
False,
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
|
||||
+ "to use ROCmScaledMMLinearKernel ",
|
||||
)
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return (
|
||||
False,
|
||||
"ROCmScaledMMLinearKernel requires "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
|
||||
return rocm_per_tensor_float_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
|
||||
|
||||
def maybe_create_device_identity():
|
||||
# Allocate dummy ones tensor for torch._scaled_mm
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
|
||||
|
||||
def torch_per_tensor_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Per-tensor-scaled w8a8 GEMM via ``torch._scaled_mm``.

    Computes ``(As * A) @ (Bs * B) + bias`` with scalar (per-tensor)
    scales and returns the result viewed as ``output_shape``.
    """
    result = torch._scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    # torch < 2.5 returned an (output, amax) tuple; torch >= 2.5 returns
    # the output tensor directly. Normalize to the tensor.
    if type(result) is tuple and len(result) == 2:
        result = result[0]

    # Trim to A's row count (drops any rows past A.shape[0]) and reshape.
    trimmed = torch.narrow(result, 0, 0, A.shape[0])
    return trimmed.view(*output_shape)
|
||||
|
||||
|
||||
def torch_row_wise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Row-wise-scaled w8a8 GEMM (fused GEMM + dequant) via ``torch._scaled_mm``.

    Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    when using it. For now it has only been validated on ROCm platform.
    fp8 rowwise scaling in torch._scaled_mm is introduced in
    https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt and
    ROCm 6.3, which only exists in torch 2.7 and above.

    For CUDA platform please validate if torch._scaled_mm supports rowwise
    scaled GEMM before using it.
    """
    # Fused GEMM_DQ rowwise GEMM; Bs arrives as a column vector and is
    # transposed to the (1, N) layout _scaled_mm expects for scale_b.
    gemm_out = torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs.t(),
        bias=bias,
    )

    # Trim to A's row count and reshape to the caller-requested shape.
    return torch.narrow(gemm_out, 0, 0, A.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
def torch_channelwise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Channel-wise-scaled w8a8 GEMM with unfused dequantization.

    Uses unfused DQ due to limitations with scaled_mm.

    Symmetric quantized GEMM by definition computes the following:
       C = (s_x * X) (s_w * W) + bias
    This is equivalent to dequantizing the weights and activations
    before applying a GEMM.

    In order to compute quantized operands, a quantized kernel
    will rewrite the above like so:
       C = s_w * s_x * (X * W) + bias

    For the scaled_mm fallback case, we break this down, since it
    does not support s_w being a vector.
    """
    # Bug fix: the identity scale must exist and live on A's device before
    # it is handed to _scaled_mm. Previously a missing call to
    # maybe_create_device_identity() left it as None, and a CPU-allocated
    # identity would mismatch a GPU input.
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
    if TORCH_DEVICE_IDENTITY.device != A.device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(A.device)

    # GEMM: computes C = (X * W).
    # Output in fp32 to allow subsequent ops to happen in-place.
    output = torch._scaled_mm(
        A,
        B,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5.
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    # Unpad (undo num_token_padding).
    output = torch.narrow(output, 0, 0, A.shape[0])
    x_scale = torch.narrow(As, 0, 0, A.shape[0])

    # DQ: C = s_w * s_x * (X * W) + bias
    output = output * x_scale * Bs.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)
|
||||
|
||||
|
||||
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
    """Base class for ``torch._scaled_mm``-backed scaled-MM linear kernels.

    Builds the activation quantizer shared by the per-tensor / row-wise /
    channel-wise subclasses and decides whether to pad the token dimension.
    """

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        compilation_config = get_current_vllm_config().compilation_config
        # Pad the token dim only below VLLM_COMPILE mode (i.e. eager-ish
        # execution); compiled mode runs without the padding.
        needs_padding = compilation_config.mode < CompilationMode.VLLM_COMPILE

        # NOTE(review): activation quantization is hard-coded to per-tensor
        # here, even for the row-/channel-wise subclasses — confirm this
        # matches the intended activation-scale granularity.
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=17 if needs_padding else None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # Lovelace (SM 8.9) and up.
        return 89

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Weights are used exactly as loaded; no repacking needed.
        return
|
||||
|
||||
|
||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """``torch._scaled_mm`` kernel for per-tensor weight and activation scales."""

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason): requires per-tensor activation AND weight scales."""
        assert c.activation_group_shape is not None
        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "PerTensorTorchScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the per-tensor-scaled fp8 GEMM for ``layer`` on input ``x``.

        Quantization supports both dynamic and static activation quant:
        if dynamic, layer.input_scale is None and x_scale is computed
        from x; if static, layer.input_scale is a scalar and x_scale is
        input_scale.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)
        # View input as 2D matrix for fp8 methods.
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized, quantize it here (possibly padding the
        # token dim -- see TorchScaledMMLinearKernel.__init__).
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )
        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        output = torch_per_tensor_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )
        # Bug fix: QuantFP8 may have padded the token dimension
        # (num_token_padding), and the GEMM narrows by the *padded* row
        # count, so padded rows survived into the result. Drop them so the
        # output has exactly one row per input token.
        if output.shape[0] != x_2d.shape[0]:
            output = torch.narrow(output, 0, 0, x_2d.shape[0])
        return output
|
||||
|
||||
|
||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """``torch._scaled_mm`` kernel using fused row-wise scaling (ROCm only)."""

    @classmethod
    def get_min_capability(cls) -> int:
        # Device capability 9.4+ (e.g. gfx94x-class) — NOTE(review): confirm
        # the exact meaning of this value on ROCm.
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason): non-per-tensor scales, ROCm, and torch>=2.7."""
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        # Fully per-tensor configs belong to the per-tensor kernel instead.
        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )

        # Row-wise _scaled_mm has only been validated on ROCm.
        if not current_platform.is_rocm():
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel is only supported "
                + "in ROCm platforms.",
            )

        # fp8 rowwise scaling in torch._scaled_mm requires torch >= 2.7
        # (hipBLASLt / ROCm 6.3, pytorch/pytorch#144432).
        if not version.parse(torch.__version__) >= version.parse("2.7"):
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel requires " + "pytorch version >=2.7.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the row-wise-scaled fp8 GEMM for ``layer`` on input ``x``.

        Quantization supports both dynamic and static activation quant:
        if dynamic, layer.input_scale is None and x_scale is computed
        from x; if static, layer.input_scale is a scalar and x_scale is
        input_scale.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)
        # View input as 2D matrix for fp8 methods.
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized, quantize it here (possibly padding the
        # token dim -- see TorchScaledMMLinearKernel.__init__).
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )
        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        output = torch_row_wise_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )
        # Bug fix: QuantFP8 may have padded the token dimension
        # (num_token_padding), and the GEMM narrows by the *padded* row
        # count, so padded rows survived into the result. Drop them so the
        # output has exactly one row per input token.
        if output.shape[0] != x_2d.shape[0]:
            output = torch.narrow(output, 0, 0, x_2d.shape[0])
        return output
|
||||
|
||||
|
||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """``torch._scaled_mm`` kernel with unfused channel-wise dequantization."""

    @classmethod
    def get_min_capability(cls) -> int:
        # Device capability 9.4+ — NOTE(review): confirm the exact meaning
        # of this value on the target platform.
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason): rejects fully per-tensor scale configs."""
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        # Fully per-tensor configs belong to the per-tensor kernel instead.
        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "ChannelWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the channel-wise-scaled fp8 GEMM for ``layer`` on input ``x``.

        Quantization supports both dynamic and static activation quant:
        if dynamic, layer.input_scale is None and x_scale is computed
        from x; if static, layer.input_scale is a scalar and x_scale is
        input_scale.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)
        # View input as 2D matrix for fp8 methods.
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized, quantize it here (possibly padding the
        # token dim -- see TorchScaledMMLinearKernel.__init__).
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )
        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        output = torch_channelwise_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )
        # Bug fix: QuantFP8 may have padded the token dimension
        # (num_token_padding), and the GEMM narrows by the *padded* row
        # count, so padded rows survived into the result. Drop them so the
        # output has exactly one row per input token.
        if output.shape[0] != x_2d.shape[0]:
            output = torch.narrow(output, 0, 0, x_2d.shape[0])
        return output
|
||||
Loading…
x
Reference in New Issue
Block a user